init
This commit is contained in:
167
app/services/subtitle.py
Normal file
167
app/services/subtitle.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import json
|
||||
import re
|
||||
|
||||
from faster_whisper import WhisperModel
|
||||
from timeit import default_timer as timer
|
||||
from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from app.models import const
|
||||
from app.utils import utils
|
||||
|
||||
# Whisper settings are read from the application config; the second argument
# of each ``get`` is the fallback used when the key is absent.
_whisper_settings = config.whisper
model_size = _whisper_settings.get("model_size", "large-v3")
device = _whisper_settings.get("device", "cpu")
compute_type = _whisper_settings.get("compute_type", "int8")

# The model is instantiated once, at import time, and shared by ``create``.
model = WhisperModel(
    model_size_or_path=model_size,
    device=device,
    compute_type=compute_type,
)
|
||||
|
||||
|
||||
def create(audio_file, subtitle_file: str = ""):
    """Transcribe an audio file and write the result as an SRT subtitle file.

    The module-level Whisper ``model`` transcribes ``audio_file`` with word
    timestamps enabled. Within each segment the words are regrouped into
    subtitle entries: an entry ends at every word for which
    ``utils.str_contains_punctuation`` is true, and whatever text is left at
    the end of the segment becomes one more entry. Entries are rendered with
    ``utils.text_to_srt`` (numbered from 1) and written to ``subtitle_file``.

    Args:
        audio_file: Audio input handed to ``model.transcribe``.
        subtitle_file: Destination path of the subtitle file. When empty,
            ``f"{audio_file}.srt"`` is used.

    Returns:
        None. The result is the file written to ``subtitle_file``.
    """
    if not subtitle_file:
        subtitle_file = f"{audio_file}.srt"
    # Logged after the default is resolved so the message names the real target.
    logger.info(f"start, output file: {subtitle_file}")

    segments, info = model.transcribe(
        audio_file,
        beam_size=5,
        word_timestamps=True,
        vad_filter=True,
        vad_parameters=dict(min_silence_duration_ms=500),
    )

    logger.info(f"detected language: '{info.language}', probability: {info.language_probability:.2f}")

    start = timer()
    subtitles = []

    def recognized(seg_text, seg_start, seg_end):
        """Record one subtitle entry; whitespace-only text is ignored."""
        seg_text = seg_text.strip()
        if not seg_text:
            return

        logger.debug(f"[{seg_start:.2f}s -> {seg_end:.2f}s] {seg_text}")

        subtitles.append({
            "msg": seg_text,
            "start_time": seg_start,
            "end_time": seg_end,
        })

    for segment in segments:
        # The word list is only iterated when truthy, so normalise a missing
        # or empty list *before* taking its length.
        words = segment.words or []
        words_idx = 0
        words_len = len(words)

        seg_start = 0
        seg_end = 0
        seg_text = ""

        is_segmented = False
        for word in words:
            if not is_segmented:
                # First word of a new subtitle entry fixes its start time.
                seg_start = word.start
                is_segmented = True

            seg_end = word.end
            seg_text += word.word

            # Break the sentence when the word contains punctuation.
            if utils.str_contains_punctuation(word.word):
                # Remove the last character.
                # NOTE(review): assumes the punctuation mark is the final
                # character of the word - confirm against
                # utils.str_contains_punctuation.
                seg_text = seg_text[:-1]
                if not seg_text:
                    # Nothing but punctuation so far. This ``continue`` also
                    # skips the ``words_idx`` increment below and leaves the
                    # entry open (original behaviour, kept on purpose).
                    continue

                recognized(seg_text, seg_start, seg_end)

                is_segmented = False
                seg_text = ""

            # Boundary adjustments taken over unchanged from the original.
            if words_idx == 0 and segment.start < word.start:
                seg_start = word.start
            if words_idx == (words_len - 1) and segment.end > word.end:
                seg_end = word.end
            words_idx += 1

        if not seg_text:
            continue

        # Text after the last punctuation mark of the segment.
        recognized(seg_text, seg_start, seg_end)

    end = timer()

    diff = end - start
    logger.info(f"complete, elapsed: {diff:.2f} s")

    # ``recognized`` never stores empty text, so every entry gets a number.
    lines = [
        utils.text_to_srt(idx, item["msg"], item["start_time"], item["end_time"])
        for idx, item in enumerate(subtitles, start=1)
    ]

    # NOTE(review): written with the platform default encoding, matching how
    # file_to_subtitles reads it back; consider moving the reader and all
    # writers in this module to encoding="utf-8" in one change.
    with open(subtitle_file, "w") as f:
        f.write("\n".join(lines))
    logger.info(f"subtitle file created: {subtitle_file}")
|
||||
|
||||
|
||||
def file_to_subtitles(filename):
    """Parse an SRT file into a list of ``(index, time_line, text)`` tuples.

    A line containing an ``HH:MM:SS,mmm``-style timestamp opens an entry; the
    lines after it are collected as the entry text until a blank line closes
    it. Lines seen while no entry is open (such as the numeric counter line)
    are ignored. Entries are numbered from 1 in the order they are closed,
    independent of the numbers stored in the file.

    Args:
        filename: Path of the subtitle file to read.

    Returns:
        A list of ``(index, time_line, text)`` tuples where ``time_line`` and
        ``text`` are stripped of surrounding whitespace. ``text`` keeps inner
        newlines for multi-line subtitles.
    """
    time_pattern = re.compile("([0-9]*:[0-9]*:[0-9]*,[0-9]*)")

    times_texts = []
    current_times = None
    current_text = ""
    index = 0
    # NOTE(review): read with the platform default encoding, matching the
    # writers in this module; consider moving all of them to UTF-8 together.
    with open(filename, 'r') as f:
        for line in f:
            if time_pattern.search(line):
                current_times = line
            elif current_times and line.strip() == '':
                index += 1
                times_texts.append((index, current_times.strip(), current_text.strip()))
                current_times, current_text = None, ""
            elif current_times:
                current_text += line

    # A file that does not end with a blank line leaves its last entry open;
    # emit it here instead of silently dropping it.
    if current_times:
        index += 1
        times_texts.append((index, current_times.strip(), current_text.strip()))

    return times_texts
|
||||
|
||||
|
||||
def correct(subtitle_file, video_script):
    """Overwrite recognised subtitle texts with the lines of the source script.

    The subtitle file is parsed with ``file_to_subtitles`` and the script is
    split with ``utils.split_string_by_punctuations``. Only when both yield
    the same number of lines are they compared position by position; every
    subtitle whose text differs from the (stripped) script line takes the
    script line instead, and the file is rewritten if anything changed.

    Args:
        subtitle_file: Path of the SRT file to check and, if needed, rewrite.
        video_script: Script text the subtitles were generated from.

    Returns:
        None. The subtitle file is modified in place when corrections apply.
    """
    subtitle_items = file_to_subtitles(subtitle_file)
    script_lines = utils.split_string_by_punctuations(video_script)

    if len(subtitle_items) != len(script_lines):
        # Lines cannot be paired one-to-one, so nothing was verified; say so
        # instead of reporting the subtitle as correct.
        logger.warning(
            f"subtitle not corrected, line count mismatch, "
            f"script: {len(script_lines)}, subtitle: {len(subtitle_items)}"
        )
        return

    corrected = False
    for i, raw_line in enumerate(script_lines):
        script_line = raw_line.strip()
        item_index, item_times, subtitle_line = subtitle_items[i]
        if script_line != subtitle_line:
            logger.warning(f"line {i + 1}, script: {script_line}, subtitle: {subtitle_line}")
            subtitle_items[i] = (item_index, item_times, script_line)
            corrected = True

    if corrected:
        # NOTE(review): written with the platform default encoding, matching
        # how file_to_subtitles reads the file; consider UTF-8 for both.
        with open(subtitle_file, "w") as fd:
            fd.writelines(
                f"{item[0]}\n{item[1]}\n{item[2]}\n\n" for item in subtitle_items
            )
        logger.info("subtitle corrected")
    else:
        logger.success("subtitle is correct")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual check against the files of one hard-coded task directory.
    task_id = "c12fd1e6-4b0a-4d65-a075-c87abe35a072"
    task_dir = utils.task_dir(task_id)
    subtitle_file = f"{task_dir}/subtitle.srt"

    # Show what the parser extracts from the subtitle file.
    print(file_to_subtitles(subtitle_file))

    # The script text is stored under the "script" key of script.json.
    script_file = f"{task_dir}/script.json"
    with open(script_file, "r") as f:
        s = json.load(f)

    correct(subtitle_file, s.get("script"))
|
||||
Reference in New Issue
Block a user