complete task query interface
This commit is contained in:
35
app/services/state.py
Normal file
35
app/services/state.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# State Management
|
||||
# This module is responsible for managing the state of the application.
|
||||
import math
|
||||
|
||||
# 如果你部署在分布式环境中,你可能需要一个中心化的状态管理服务,比如 Redis 或者数据库。
|
||||
# 如果你的应用程序是单机的,你可以使用内存来存储状态。
|
||||
|
||||
# If you are deploying in a distributed environment, you might need a centralized state management service like Redis or a database.
|
||||
# If your application is single-node, you can use memory to store the state.
|
||||
|
||||
from app.models import const
|
||||
from app.utils import utils
|
||||
|
||||
_tasks = {}
|
||||
|
||||
|
||||
def update_task(task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
|
||||
"""
|
||||
Set the state of the task.
|
||||
"""
|
||||
progress = int(progress)
|
||||
if progress > 100:
|
||||
progress = 100
|
||||
|
||||
_tasks[task_id] = {
|
||||
"state": state,
|
||||
"progress": progress,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
def get_task(task_id: str):
|
||||
"""
|
||||
Get the state of the task.
|
||||
"""
|
||||
return _tasks.get(task_id, None)
|
||||
@@ -6,8 +6,10 @@ from os import path
|
||||
from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from app.models import const
|
||||
from app.models.schema import VideoParams, VideoConcatMode
|
||||
from app.services import llm, material, voice, video, subtitle
|
||||
from app.services import state as sm
|
||||
from app.utils import utils
|
||||
|
||||
|
||||
@@ -26,6 +28,8 @@ def start(task_id, params: VideoParams):
|
||||
}
|
||||
"""
|
||||
logger.info(f"start task: {task_id}")
|
||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
|
||||
|
||||
video_subject = params.video_subject
|
||||
voice_name = voice.parse_voice_name(params.voice_name)
|
||||
paragraph_number = params.paragraph_number
|
||||
@@ -40,6 +44,8 @@ def start(task_id, params: VideoParams):
|
||||
else:
|
||||
logger.debug(f"video script: \n{video_script}")
|
||||
|
||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
|
||||
|
||||
logger.info("\n\n## generating video terms")
|
||||
video_terms = params.video_terms
|
||||
if not video_terms:
|
||||
@@ -63,10 +69,13 @@ def start(task_id, params: VideoParams):
|
||||
with open(script_file, "w", encoding="utf-8") as f:
|
||||
f.write(utils.to_json(script_data))
|
||||
|
||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
|
||||
|
||||
logger.info("\n\n## generating audio")
|
||||
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3")
|
||||
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file)
|
||||
if sub_maker is None:
|
||||
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||
logger.error(
|
||||
"failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.")
|
||||
return
|
||||
@@ -74,6 +83,8 @@ def start(task_id, params: VideoParams):
|
||||
audio_duration = voice.get_audio_duration(sub_maker)
|
||||
audio_duration = math.ceil(audio_duration)
|
||||
|
||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
|
||||
|
||||
subtitle_path = ""
|
||||
if params.subtitle_enabled:
|
||||
subtitle_path = path.join(utils.task_dir(task_id), f"subtitle.srt")
|
||||
@@ -101,6 +112,8 @@ def start(task_id, params: VideoParams):
|
||||
logger.warning(f"subtitle file is invalid: {subtitle_path}")
|
||||
subtitle_path = ""
|
||||
|
||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
|
||||
|
||||
logger.info("\n\n## downloading videos")
|
||||
downloaded_videos = material.download_videos(task_id=task_id,
|
||||
search_terms=video_terms,
|
||||
@@ -110,15 +123,19 @@ def start(task_id, params: VideoParams):
|
||||
max_clip_duration=max_clip_duration,
|
||||
)
|
||||
if not downloaded_videos:
|
||||
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||
logger.error(
|
||||
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.")
|
||||
return
|
||||
|
||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
|
||||
|
||||
final_video_paths = []
|
||||
video_concat_mode = params.video_concat_mode
|
||||
if params.video_count > 1:
|
||||
video_concat_mode = VideoConcatMode.random
|
||||
|
||||
_progress = 50
|
||||
for i in range(params.video_count):
|
||||
index = i + 1
|
||||
combined_video_path = path.join(utils.task_dir(task_id), f"combined-{index}.mp4")
|
||||
@@ -131,6 +148,9 @@ def start(task_id, params: VideoParams):
|
||||
max_clip_duration=max_clip_duration,
|
||||
threads=n_threads)
|
||||
|
||||
_progress += 50 / params.video_count / 2
|
||||
sm.update_task(task_id, progress=_progress)
|
||||
|
||||
final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")
|
||||
|
||||
logger.info(f"\n\n## generating video: {index} => {final_video_path}")
|
||||
@@ -141,10 +161,16 @@ def start(task_id, params: VideoParams):
|
||||
output_file=final_video_path,
|
||||
params=params,
|
||||
)
|
||||
|
||||
_progress += 50 / params.video_count / 2
|
||||
sm.update_task(task_id, progress=_progress)
|
||||
|
||||
final_video_paths.append(final_video_path)
|
||||
|
||||
logger.success(f"task {task_id} finished, generated {len(final_video_paths)} videos.")
|
||||
|
||||
return {
|
||||
kwargs = {
|
||||
"videos": final_video_paths,
|
||||
}
|
||||
sm.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
|
||||
return kwargs
|
||||
|
||||
Reference in New Issue
Block a user