Format project code

This commit is contained in:
yyhhyyyyyy
2024-07-24 14:59:06 +08:00
parent bbd4e94941
commit 905841965a
18 changed files with 410 additions and 214 deletions

View File

@@ -23,18 +23,22 @@ def create(audio_file, subtitle_file: str = ""):
if not os.path.isdir(model_path) or not os.path.isfile(model_bin_file):
model_path = model_size
logger.info(f"loading model: {model_path}, device: {device}, compute_type: {compute_type}")
logger.info(
f"loading model: {model_path}, device: {device}, compute_type: {compute_type}"
)
try:
model = WhisperModel(model_size_or_path=model_path,
device=device,
compute_type=compute_type)
model = WhisperModel(
model_size_or_path=model_path, device=device, compute_type=compute_type
)
except Exception as e:
logger.error(f"failed to load model: {e} \n\n"
f"********************************************\n"
f"this may be caused by network issue. \n"
f"please download the model manually and put it in the 'models' folder. \n"
f"see [README.md FAQ](https://github.com/harry0703/MoneyPrinterTurbo) for more details.\n"
f"********************************************\n\n")
logger.error(
f"failed to load model: {e} \n\n"
f"********************************************\n"
f"this may be caused by network issue. \n"
f"please download the model manually and put it in the 'models' folder. \n"
f"see [README.md FAQ](https://github.com/harry0703/MoneyPrinterTurbo) for more details.\n"
f"********************************************\n\n"
)
return None
logger.info(f"start, output file: {subtitle_file}")
@@ -49,7 +53,9 @@ def create(audio_file, subtitle_file: str = ""):
vad_parameters=dict(min_silence_duration_ms=500),
)
logger.info(f"detected language: '{info.language}', probability: {info.language_probability:.2f}")
logger.info(
f"detected language: '{info.language}', probability: {info.language_probability:.2f}"
)
start = timer()
subtitles = []
@@ -62,11 +68,9 @@ def create(audio_file, subtitle_file: str = ""):
msg = "[%.2fs -> %.2fs] %s" % (seg_start, seg_end, seg_text)
logger.debug(msg)
subtitles.append({
"msg": seg_text,
"start_time": seg_start,
"end_time": seg_end
})
subtitles.append(
{"msg": seg_text, "start_time": seg_start, "end_time": seg_end}
)
for segment in segments:
words_idx = 0
@@ -119,7 +123,11 @@ def create(audio_file, subtitle_file: str = ""):
for subtitle in subtitles:
text = subtitle.get("msg")
if text:
lines.append(utils.text_to_srt(idx, text, subtitle.get("start_time"), subtitle.get("end_time")))
lines.append(
utils.text_to_srt(
idx, text, subtitle.get("start_time"), subtitle.get("end_time")
)
)
idx += 1
sub = "\n".join(lines) + "\n"
@@ -136,12 +144,12 @@ def file_to_subtitles(filename):
current_times = None
current_text = ""
index = 0
with open(filename, 'r', encoding="utf-8") as f:
with open(filename, "r", encoding="utf-8") as f:
for line in f:
times = re.findall("([0-9]*:[0-9]*:[0-9]*,[0-9]*)", line)
if times:
current_times = line
elif line.strip() == '' and current_times:
elif line.strip() == "" and current_times:
index += 1
times_texts.append((index, current_times.strip(), current_text.strip()))
current_times, current_text = None, ""
@@ -166,9 +174,10 @@ def levenshtein_distance(s1, s2):
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row
return previous_row[-1]
def similarity(a, b):
distance = levenshtein_distance(a.lower(), b.lower())
max_length = max(len(a), len(b))
@@ -194,26 +203,44 @@ def correct(subtitle_file, video_script):
subtitle_index += 1
else:
combined_subtitle = subtitle_line
start_time = subtitle_items[subtitle_index][1].split(' --> ')[0]
end_time = subtitle_items[subtitle_index][1].split(' --> ')[1]
start_time = subtitle_items[subtitle_index][1].split(" --> ")[0]
end_time = subtitle_items[subtitle_index][1].split(" --> ")[1]
next_subtitle_index = subtitle_index + 1
while next_subtitle_index < len(subtitle_items):
next_subtitle = subtitle_items[next_subtitle_index][2].strip()
if similarity(script_line, combined_subtitle + " " + next_subtitle) > similarity(script_line, combined_subtitle):
if similarity(
script_line, combined_subtitle + " " + next_subtitle
) > similarity(script_line, combined_subtitle):
combined_subtitle += " " + next_subtitle
end_time = subtitle_items[next_subtitle_index][1].split(' --> ')[1]
end_time = subtitle_items[next_subtitle_index][1].split(" --> ")[1]
next_subtitle_index += 1
else:
break
if similarity(script_line, combined_subtitle) > 0.8:
logger.warning(f"Merged/Corrected - Script: {script_line}, Subtitle: {combined_subtitle}")
new_subtitle_items.append((len(new_subtitle_items) + 1, f"{start_time} --> {end_time}", script_line))
logger.warning(
f"Merged/Corrected - Script: {script_line}, Subtitle: {combined_subtitle}"
)
new_subtitle_items.append(
(
len(new_subtitle_items) + 1,
f"{start_time} --> {end_time}",
script_line,
)
)
corrected = True
else:
logger.warning(f"Mismatch - Script: {script_line}, Subtitle: {combined_subtitle}")
new_subtitle_items.append((len(new_subtitle_items) + 1, f"{start_time} --> {end_time}", script_line))
logger.warning(
f"Mismatch - Script: {script_line}, Subtitle: {combined_subtitle}"
)
new_subtitle_items.append(
(
len(new_subtitle_items) + 1,
f"{start_time} --> {end_time}",
script_line,
)
)
corrected = True
script_index += 1
@@ -223,10 +250,22 @@ def correct(subtitle_file, video_script):
while script_index < len(script_lines):
logger.warning(f"Extra script line: {script_lines[script_index]}")
if subtitle_index < len(subtitle_items):
new_subtitle_items.append((len(new_subtitle_items) + 1, subtitle_items[subtitle_index][1], script_lines[script_index]))
new_subtitle_items.append(
(
len(new_subtitle_items) + 1,
subtitle_items[subtitle_index][1],
script_lines[script_index],
)
)
subtitle_index += 1
else:
new_subtitle_items.append((len(new_subtitle_items) + 1, "00:00:00,000 --> 00:00:00,000", script_lines[script_index]))
new_subtitle_items.append(
(
len(new_subtitle_items) + 1,
"00:00:00,000 --> 00:00:00,000",
script_lines[script_index],
)
)
script_index += 1
corrected = True