feat: add redis support for task state management

This commit is contained in:
kevin.zhang
2024-04-10 10:42:56 +08:00
parent a0944fa358
commit 3d45348662
5 changed files with 111 additions and 44 deletions

View File

@@ -29,7 +29,7 @@ def create_video(background_tasks: BackgroundTasks, request: Request, body: Task
"request_id": request_id, "request_id": request_id,
"params": body.dict(), "params": body.dict(),
} }
sm.update_task(task_id) sm.state.update_task(task_id)
background_tasks.add_task(tm.start, task_id=task_id, params=body) background_tasks.add_task(tm.start, task_id=task_id, params=body)
logger.success(f"video created: {utils.to_json(task)}") logger.success(f"video created: {utils.to_json(task)}")
return utils.get_response(200, task) return utils.get_response(200, task)
@@ -46,7 +46,7 @@ def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
endpoint = endpoint.rstrip("/") endpoint = endpoint.rstrip("/")
request_id = base.get_task_id(request) request_id = base.get_task_id(request)
task = sm.get_task(task_id) task = sm.state.get_task(task_id)
if task: if task:
task_dir = utils.task_dir() task_dir = utils.task_dir()

View File

@@ -1,35 +1,96 @@
# State Management import ast
# This module is responsible for managing the state of the application. import json
import math from abc import ABC, abstractmethod
import redis
# 如果你部署在分布式环境中,你可能需要一个中心化的状态管理服务,比如 Redis 或者数据库。 from app.config import config
# 如果你的应用程序是单机的,你可以使用内存来存储状态。
# If you are deploying in a distributed environment, you might need a centralized state management service like Redis or a database.
# If your application is single-node, you can use memory to store the state.
from app.models import const from app.models import const
from app.utils import utils
_tasks = {}
def update_task(task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs): # Base class for state management
""" class BaseState(ABC):
Set the state of the task.
""" @abstractmethod
def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
pass
@abstractmethod
def get_task(self, task_id: str):
pass
# Memory state management
class MemoryState(BaseState):
def __init__(self):
self._tasks = {}
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
progress = int(progress) progress = int(progress)
if progress > 100: if progress > 100:
progress = 100 progress = 100
_tasks[task_id] = { self._tasks[task_id] = {
"state": state, "state": state,
"progress": progress, "progress": progress,
**kwargs, **kwargs,
} }
def get_task(task_id: str): def get_task(self, task_id: str):
return self._tasks.get(task_id, None)
# Redis state management
class RedisState(BaseState):
def __init__(self, host='localhost', port=6379, db=0):
self._redis = redis.StrictRedis(host=host, port=port, db=db)
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
progress = int(progress)
if progress > 100:
progress = 100
fields = {
"state": state,
"progress": progress,
**kwargs,
}
for field, value in fields.items():
self._redis.hset(task_id, field, str(value))
def get_task(self, task_id: str):
task_data = self._redis.hgetall(task_id)
if not task_data:
return None
task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()}
return task
@staticmethod
def _convert_to_original_type(value):
""" """
Get the state of the task. Convert the value from byte string to its original data type.
You can extend this method to handle other data types as needed.
""" """
return _tasks.get(task_id, None) value_str = value.decode('utf-8')
try:
# try to convert byte string array to list
return ast.literal_eval(value_str)
except (ValueError, SyntaxError):
pass
if value_str.isdigit():
return int(value_str)
# Add more conversions here if needed
return value_str
# Global state
_enable_redis = config.app.get("enable_redis", False)
_redis_host = config.app.get("redis_host", "localhost")
_redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db) if _enable_redis else MemoryState()

View File

@@ -28,7 +28,7 @@ def start(task_id, params: VideoParams):
} }
""" """
logger.info(f"start task: {task_id}") logger.info(f"start task: {task_id}")
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5) sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
video_subject = params.video_subject video_subject = params.video_subject
voice_name = voice.parse_voice_name(params.voice_name) voice_name = voice.parse_voice_name(params.voice_name)
@@ -44,7 +44,7 @@ def start(task_id, params: VideoParams):
else: else:
logger.debug(f"video script: \n{video_script}") logger.debug(f"video script: \n{video_script}")
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10) sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
logger.info("\n\n## generating video terms") logger.info("\n\n## generating video terms")
video_terms = params.video_terms video_terms = params.video_terms
@@ -70,13 +70,13 @@ def start(task_id, params: VideoParams):
with open(script_file, "w", encoding="utf-8") as f: with open(script_file, "w", encoding="utf-8") as f:
f.write(utils.to_json(script_data)) f.write(utils.to_json(script_data))
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20) sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
logger.info("\n\n## generating audio") logger.info("\n\n## generating audio")
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3") audio_file = path.join(utils.task_dir(task_id), f"audio.mp3")
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file) sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file)
if sub_maker is None: if sub_maker is None:
sm.update_task(task_id, state=const.TASK_STATE_FAILED) sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error( logger.error(
"failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.") "failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.")
return return
@@ -84,7 +84,7 @@ def start(task_id, params: VideoParams):
audio_duration = voice.get_audio_duration(sub_maker) audio_duration = voice.get_audio_duration(sub_maker)
audio_duration = math.ceil(audio_duration) audio_duration = math.ceil(audio_duration)
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30) sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
subtitle_path = "" subtitle_path = ""
if params.subtitle_enabled: if params.subtitle_enabled:
@@ -108,7 +108,7 @@ def start(task_id, params: VideoParams):
logger.warning(f"subtitle file is invalid: {subtitle_path}") logger.warning(f"subtitle file is invalid: {subtitle_path}")
subtitle_path = "" subtitle_path = ""
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40) sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
logger.info("\n\n## downloading videos") logger.info("\n\n## downloading videos")
downloaded_videos = material.download_videos(task_id=task_id, downloaded_videos = material.download_videos(task_id=task_id,
@@ -119,12 +119,12 @@ def start(task_id, params: VideoParams):
max_clip_duration=max_clip_duration, max_clip_duration=max_clip_duration,
) )
if not downloaded_videos: if not downloaded_videos:
sm.update_task(task_id, state=const.TASK_STATE_FAILED) sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error( logger.error(
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.") "failed to download videos, maybe the network is not available. if you are in China, please use a VPN.")
return return
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50) sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
final_video_paths = [] final_video_paths = []
combined_video_paths = [] combined_video_paths = []
@@ -146,7 +146,7 @@ def start(task_id, params: VideoParams):
threads=n_threads) threads=n_threads)
_progress += 50 / params.video_count / 2 _progress += 50 / params.video_count / 2
sm.update_task(task_id, progress=_progress) sm.state.update_task(task_id, progress=_progress)
final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4") final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")
@@ -160,7 +160,7 @@ def start(task_id, params: VideoParams):
) )
_progress += 50 / params.video_count / 2 _progress += 50 / params.video_count / 2
sm.update_task(task_id, progress=_progress) sm.state.update_task(task_id, progress=_progress)
final_video_paths.append(final_video_path) final_video_paths.append(final_video_path)
combined_video_paths.append(combined_video_path) combined_video_paths.append(combined_video_path)
@@ -171,5 +171,5 @@ def start(task_id, params: VideoParams):
"videos": final_video_paths, "videos": final_video_paths,
"combined_videos": combined_video_paths "combined_videos": combined_video_paths
} }
sm.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs) sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
return kwargs return kwargs

View File

@@ -129,6 +129,11 @@
material_directory = "" material_directory = ""
# Used for state management of the task
enable_redis = true
redis_host = "localhost"
redis_port = 6379
redis_db = 0
[whisper] [whisper]
# Only effective when subtitle_provider is "whisper" # Only effective when subtitle_provider is "whisper"

View File

@@ -16,3 +16,4 @@ g4f~=0.2.5.4
dashscope~=1.15.0 dashscope~=1.15.0
google.generativeai~=0.4.1 google.generativeai~=0.4.1
python-multipart~=0.0.9 python-multipart~=0.0.9
redis==5.0.3