Merge pull request #216 from KevinZhang19870314/main

feat: add redis support for task state management
This commit is contained in:
Harry
2024-04-10 13:41:13 +08:00
committed by GitHub
5 changed files with 111 additions and 44 deletions

View File

@@ -29,7 +29,7 @@ def create_video(background_tasks: BackgroundTasks, request: Request, body: Task
"request_id": request_id, "request_id": request_id,
"params": body.dict(), "params": body.dict(),
} }
sm.update_task(task_id) sm.state.update_task(task_id)
background_tasks.add_task(tm.start, task_id=task_id, params=body) background_tasks.add_task(tm.start, task_id=task_id, params=body)
logger.success(f"video created: {utils.to_json(task)}") logger.success(f"video created: {utils.to_json(task)}")
return utils.get_response(200, task) return utils.get_response(200, task)
@@ -46,7 +46,7 @@ def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
endpoint = endpoint.rstrip("/") endpoint = endpoint.rstrip("/")
request_id = base.get_task_id(request) request_id = base.get_task_id(request)
task = sm.get_task(task_id) task = sm.state.get_task(task_id)
if task: if task:
task_dir = utils.task_dir() task_dir = utils.task_dir()

View File

@@ -1,35 +1,96 @@
# State Management import ast
# This module is responsible for managing the state of the application. import json
import math from abc import ABC, abstractmethod
import redis
# 如果你部署在分布式环境中,你可能需要一个中心化的状态管理服务,比如 Redis 或者数据库。 from app.config import config
# 如果你的应用程序是单机的,你可以使用内存来存储状态。
# If you are deploying in a distributed environment, you might need a centralized state management service like Redis or a database.
# If your application is single-node, you can use memory to store the state.
from app.models import const from app.models import const
from app.utils import utils
_tasks = {}
def update_task(task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs): # Base class for state management
""" class BaseState(ABC):
Set the state of the task.
"""
progress = int(progress)
if progress > 100:
progress = 100
_tasks[task_id] = { @abstractmethod
"state": state, def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
"progress": progress, pass
**kwargs,
}
def get_task(task_id: str): @abstractmethod
""" def get_task(self, task_id: str):
Get the state of the task. pass
"""
return _tasks.get(task_id, None)
# Memory state management
class MemoryState(BaseState):
def __init__(self):
self._tasks = {}
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
progress = int(progress)
if progress > 100:
progress = 100
self._tasks[task_id] = {
"state": state,
"progress": progress,
**kwargs,
}
def get_task(self, task_id: str):
return self._tasks.get(task_id, None)
# Redis state management
class RedisState(BaseState):
def __init__(self, host='localhost', port=6379, db=0):
self._redis = redis.StrictRedis(host=host, port=port, db=db)
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
progress = int(progress)
if progress > 100:
progress = 100
fields = {
"state": state,
"progress": progress,
**kwargs,
}
for field, value in fields.items():
self._redis.hset(task_id, field, str(value))
def get_task(self, task_id: str):
task_data = self._redis.hgetall(task_id)
if not task_data:
return None
task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()}
return task
@staticmethod
def _convert_to_original_type(value):
"""
Convert the value from byte string to its original data type.
You can extend this method to handle other data types as needed.
"""
value_str = value.decode('utf-8')
try:
# try to convert byte string array to list
return ast.literal_eval(value_str)
except (ValueError, SyntaxError):
pass
if value_str.isdigit():
return int(value_str)
# Add more conversions here if needed
return value_str
# Global state
_enable_redis = config.app.get("enable_redis", False)
_redis_host = config.app.get("redis_host", "localhost")
_redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db) if _enable_redis else MemoryState()

View File

@@ -28,7 +28,7 @@ def start(task_id, params: VideoParams):
} }
""" """
logger.info(f"start task: {task_id}") logger.info(f"start task: {task_id}")
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5) sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
video_subject = params.video_subject video_subject = params.video_subject
voice_name = voice.parse_voice_name(params.voice_name) voice_name = voice.parse_voice_name(params.voice_name)
@@ -44,7 +44,7 @@ def start(task_id, params: VideoParams):
else: else:
logger.debug(f"video script: \n{video_script}") logger.debug(f"video script: \n{video_script}")
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10) sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
logger.info("\n\n## generating video terms") logger.info("\n\n## generating video terms")
video_terms = params.video_terms video_terms = params.video_terms
@@ -70,13 +70,13 @@ def start(task_id, params: VideoParams):
with open(script_file, "w", encoding="utf-8") as f: with open(script_file, "w", encoding="utf-8") as f:
f.write(utils.to_json(script_data)) f.write(utils.to_json(script_data))
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20) sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
logger.info("\n\n## generating audio") logger.info("\n\n## generating audio")
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3") audio_file = path.join(utils.task_dir(task_id), f"audio.mp3")
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file) sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file)
if sub_maker is None: if sub_maker is None:
sm.update_task(task_id, state=const.TASK_STATE_FAILED) sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error( logger.error(
"failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.") "failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.")
return return
@@ -84,7 +84,7 @@ def start(task_id, params: VideoParams):
audio_duration = voice.get_audio_duration(sub_maker) audio_duration = voice.get_audio_duration(sub_maker)
audio_duration = math.ceil(audio_duration) audio_duration = math.ceil(audio_duration)
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30) sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
subtitle_path = "" subtitle_path = ""
if params.subtitle_enabled: if params.subtitle_enabled:
@@ -108,7 +108,7 @@ def start(task_id, params: VideoParams):
logger.warning(f"subtitle file is invalid: {subtitle_path}") logger.warning(f"subtitle file is invalid: {subtitle_path}")
subtitle_path = "" subtitle_path = ""
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40) sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
logger.info("\n\n## downloading videos") logger.info("\n\n## downloading videos")
downloaded_videos = material.download_videos(task_id=task_id, downloaded_videos = material.download_videos(task_id=task_id,
@@ -119,12 +119,12 @@ def start(task_id, params: VideoParams):
max_clip_duration=max_clip_duration, max_clip_duration=max_clip_duration,
) )
if not downloaded_videos: if not downloaded_videos:
sm.update_task(task_id, state=const.TASK_STATE_FAILED) sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error( logger.error(
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.") "failed to download videos, maybe the network is not available. if you are in China, please use a VPN.")
return return
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50) sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
final_video_paths = [] final_video_paths = []
combined_video_paths = [] combined_video_paths = []
@@ -146,7 +146,7 @@ def start(task_id, params: VideoParams):
threads=n_threads) threads=n_threads)
_progress += 50 / params.video_count / 2 _progress += 50 / params.video_count / 2
sm.update_task(task_id, progress=_progress) sm.state.update_task(task_id, progress=_progress)
final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4") final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")
@@ -160,7 +160,7 @@ def start(task_id, params: VideoParams):
) )
_progress += 50 / params.video_count / 2 _progress += 50 / params.video_count / 2
sm.update_task(task_id, progress=_progress) sm.state.update_task(task_id, progress=_progress)
final_video_paths.append(final_video_path) final_video_paths.append(final_video_path)
combined_video_paths.append(combined_video_path) combined_video_paths.append(combined_video_path)
@@ -171,5 +171,5 @@ def start(task_id, params: VideoParams):
"videos": final_video_paths, "videos": final_video_paths,
"combined_videos": combined_video_paths "combined_videos": combined_video_paths
} }
sm.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs) sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
return kwargs return kwargs

View File

@@ -129,6 +129,11 @@
material_directory = "" material_directory = ""
# Used for state management of the task
enable_redis = false
redis_host = "localhost"
redis_port = 6379
redis_db = 0
[whisper] [whisper]
# Only effective when subtitle_provider is "whisper" # Only effective when subtitle_provider is "whisper"

View File

@@ -15,4 +15,5 @@ pydantic~=2.6.3
g4f~=0.2.5.4 g4f~=0.2.5.4
dashscope~=1.15.0 dashscope~=1.15.0
google.generativeai~=0.4.1 google.generativeai~=0.4.1
python-multipart~=0.0.9 python-multipart~=0.0.9
redis==5.0.3