From 4a800eab4bebec6811433977f332997386e1552a Mon Sep 17 00:00:00 2001 From: harry Date: Mon, 18 Mar 2024 22:05:17 +0800 Subject: [PATCH] tts fallback --- app/services/subtitle.py | 9 ++++++--- app/services/task.py | 8 +++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/app/services/subtitle.py b/app/services/subtitle.py index f86db9d..e62ce57 100644 --- a/app/services/subtitle.py +++ b/app/services/subtitle.py @@ -11,12 +11,15 @@ from app.utils import utils model_size = config.whisper.get("model_size", "large-v3") device = config.whisper.get("device", "cpu") compute_type = config.whisper.get("compute_type", "int8") - -if config.app.get("subtitle_provider") == "whisper": - model = WhisperModel(model_size_or_path=model_size, device=device, compute_type=compute_type) +model = None def create(audio_file, subtitle_file: str = ""): + global model + if not model: + logger.info(f"loading model: {model_size}, device: {device}, compute_type: {compute_type}") + model = WhisperModel(model_size_or_path=model_size, device=device, compute_type=compute_type) + logger.info(f"start, output file: {subtitle_file}") if not subtitle_file: subtitle_file = f"{audio_file}.srt" diff --git a/app/services/task.py b/app/services/task.py index 079556b..1961608 100644 --- a/app/services/task.py +++ b/app/services/task.py @@ -1,3 +1,4 @@ +import os.path from os import path from loguru import logger @@ -64,9 +65,14 @@ def start(task_id, params: VideoParams): subtitle_provider = config.app.get("subtitle_provider", "").strip().lower() logger.info(f"\n\n## generating subtitle, provider: {subtitle_provider}") + subtitle_fallback = False if subtitle_provider == "edge": voice.create_subtitle(text=script, sub_maker=sub_maker, subtitle_file=subtitle_path) - if subtitle_provider == "whisper": + if not os.path.exists(subtitle_path): + subtitle_fallback = True + logger.warning("subtitle file not found, fallback to whisper") + + if subtitle_provider == "whisper" or subtitle_fallback: subtitle.create(audio_file=audio_file, subtitle_file=subtitle_path) logger.info("\n\n## correcting subtitle") subtitle.correct(subtitle_file=subtitle_path, video_script=script)