Merge pull request #216 from KevinZhang19870314/main
feat: add redis support for task state management
This commit is contained in:
@@ -29,7 +29,7 @@ def create_video(background_tasks: BackgroundTasks, request: Request, body: Task
|
|||||||
"request_id": request_id,
|
"request_id": request_id,
|
||||||
"params": body.dict(),
|
"params": body.dict(),
|
||||||
}
|
}
|
||||||
sm.update_task(task_id)
|
sm.state.update_task(task_id)
|
||||||
background_tasks.add_task(tm.start, task_id=task_id, params=body)
|
background_tasks.add_task(tm.start, task_id=task_id, params=body)
|
||||||
logger.success(f"video created: {utils.to_json(task)}")
|
logger.success(f"video created: {utils.to_json(task)}")
|
||||||
return utils.get_response(200, task)
|
return utils.get_response(200, task)
|
||||||
@@ -46,7 +46,7 @@ def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
|
|||||||
endpoint = endpoint.rstrip("/")
|
endpoint = endpoint.rstrip("/")
|
||||||
|
|
||||||
request_id = base.get_task_id(request)
|
request_id = base.get_task_id(request)
|
||||||
task = sm.get_task(task_id)
|
task = sm.state.get_task(task_id)
|
||||||
if task:
|
if task:
|
||||||
task_dir = utils.task_dir()
|
task_dir = utils.task_dir()
|
||||||
|
|
||||||
|
|||||||
@@ -1,35 +1,96 @@
|
|||||||
# State Management
|
import ast
|
||||||
# This module is responsible for managing the state of the application.
|
import json
|
||||||
import math
|
from abc import ABC, abstractmethod
|
||||||
|
import redis
|
||||||
# 如果你部署在分布式环境中,你可能需要一个中心化的状态管理服务,比如 Redis 或者数据库。
|
from app.config import config
|
||||||
# 如果你的应用程序是单机的,你可以使用内存来存储状态。
|
|
||||||
|
|
||||||
# If you are deploying in a distributed environment, you might need a centralized state management service like Redis or a database.
|
|
||||||
# If your application is single-node, you can use memory to store the state.
|
|
||||||
|
|
||||||
from app.models import const
|
from app.models import const
|
||||||
from app.utils import utils
|
|
||||||
|
|
||||||
_tasks = {}
|
|
||||||
|
|
||||||
|
|
||||||
def update_task(task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
|
# Base class for state management
|
||||||
"""
|
class BaseState(ABC):
|
||||||
Set the state of the task.
|
|
||||||
"""
|
|
||||||
progress = int(progress)
|
|
||||||
if progress > 100:
|
|
||||||
progress = 100
|
|
||||||
|
|
||||||
_tasks[task_id] = {
|
@abstractmethod
|
||||||
"state": state,
|
def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
|
||||||
"progress": progress,
|
pass
|
||||||
**kwargs,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_task(task_id: str):
|
@abstractmethod
|
||||||
"""
|
def get_task(self, task_id: str):
|
||||||
Get the state of the task.
|
pass
|
||||||
"""
|
|
||||||
return _tasks.get(task_id, None)
|
|
||||||
|
# Memory state management
|
||||||
|
class MemoryState(BaseState):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._tasks = {}
|
||||||
|
|
||||||
|
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
|
||||||
|
progress = int(progress)
|
||||||
|
if progress > 100:
|
||||||
|
progress = 100
|
||||||
|
|
||||||
|
self._tasks[task_id] = {
|
||||||
|
"state": state,
|
||||||
|
"progress": progress,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_task(self, task_id: str):
|
||||||
|
return self._tasks.get(task_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
# Redis state management
|
||||||
|
class RedisState(BaseState):
|
||||||
|
|
||||||
|
def __init__(self, host='localhost', port=6379, db=0):
|
||||||
|
self._redis = redis.StrictRedis(host=host, port=port, db=db)
|
||||||
|
|
||||||
|
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
|
||||||
|
progress = int(progress)
|
||||||
|
if progress > 100:
|
||||||
|
progress = 100
|
||||||
|
|
||||||
|
fields = {
|
||||||
|
"state": state,
|
||||||
|
"progress": progress,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
for field, value in fields.items():
|
||||||
|
self._redis.hset(task_id, field, str(value))
|
||||||
|
|
||||||
|
def get_task(self, task_id: str):
|
||||||
|
task_data = self._redis.hgetall(task_id)
|
||||||
|
if not task_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()}
|
||||||
|
return task
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_to_original_type(value):
|
||||||
|
"""
|
||||||
|
Convert the value from byte string to its original data type.
|
||||||
|
You can extend this method to handle other data types as needed.
|
||||||
|
"""
|
||||||
|
value_str = value.decode('utf-8')
|
||||||
|
|
||||||
|
try:
|
||||||
|
# try to convert byte string array to list
|
||||||
|
return ast.literal_eval(value_str)
|
||||||
|
except (ValueError, SyntaxError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
if value_str.isdigit():
|
||||||
|
return int(value_str)
|
||||||
|
# Add more conversions here if needed
|
||||||
|
return value_str
|
||||||
|
|
||||||
|
|
||||||
|
# Global state
|
||||||
|
_enable_redis = config.app.get("enable_redis", False)
|
||||||
|
_redis_host = config.app.get("redis_host", "localhost")
|
||||||
|
_redis_port = config.app.get("redis_port", 6379)
|
||||||
|
_redis_db = config.app.get("redis_db", 0)
|
||||||
|
|
||||||
|
state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db) if _enable_redis else MemoryState()
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ def start(task_id, params: VideoParams):
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
logger.info(f"start task: {task_id}")
|
logger.info(f"start task: {task_id}")
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
|
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
|
||||||
|
|
||||||
video_subject = params.video_subject
|
video_subject = params.video_subject
|
||||||
voice_name = voice.parse_voice_name(params.voice_name)
|
voice_name = voice.parse_voice_name(params.voice_name)
|
||||||
@@ -44,7 +44,7 @@ def start(task_id, params: VideoParams):
|
|||||||
else:
|
else:
|
||||||
logger.debug(f"video script: \n{video_script}")
|
logger.debug(f"video script: \n{video_script}")
|
||||||
|
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
|
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
|
||||||
|
|
||||||
logger.info("\n\n## generating video terms")
|
logger.info("\n\n## generating video terms")
|
||||||
video_terms = params.video_terms
|
video_terms = params.video_terms
|
||||||
@@ -70,13 +70,13 @@ def start(task_id, params: VideoParams):
|
|||||||
with open(script_file, "w", encoding="utf-8") as f:
|
with open(script_file, "w", encoding="utf-8") as f:
|
||||||
f.write(utils.to_json(script_data))
|
f.write(utils.to_json(script_data))
|
||||||
|
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
|
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
|
||||||
|
|
||||||
logger.info("\n\n## generating audio")
|
logger.info("\n\n## generating audio")
|
||||||
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3")
|
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3")
|
||||||
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file)
|
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file)
|
||||||
if sub_maker is None:
|
if sub_maker is None:
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
|
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||||
logger.error(
|
logger.error(
|
||||||
"failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.")
|
"failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.")
|
||||||
return
|
return
|
||||||
@@ -84,7 +84,7 @@ def start(task_id, params: VideoParams):
|
|||||||
audio_duration = voice.get_audio_duration(sub_maker)
|
audio_duration = voice.get_audio_duration(sub_maker)
|
||||||
audio_duration = math.ceil(audio_duration)
|
audio_duration = math.ceil(audio_duration)
|
||||||
|
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
|
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
|
||||||
|
|
||||||
subtitle_path = ""
|
subtitle_path = ""
|
||||||
if params.subtitle_enabled:
|
if params.subtitle_enabled:
|
||||||
@@ -108,7 +108,7 @@ def start(task_id, params: VideoParams):
|
|||||||
logger.warning(f"subtitle file is invalid: {subtitle_path}")
|
logger.warning(f"subtitle file is invalid: {subtitle_path}")
|
||||||
subtitle_path = ""
|
subtitle_path = ""
|
||||||
|
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
|
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
|
||||||
|
|
||||||
logger.info("\n\n## downloading videos")
|
logger.info("\n\n## downloading videos")
|
||||||
downloaded_videos = material.download_videos(task_id=task_id,
|
downloaded_videos = material.download_videos(task_id=task_id,
|
||||||
@@ -119,12 +119,12 @@ def start(task_id, params: VideoParams):
|
|||||||
max_clip_duration=max_clip_duration,
|
max_clip_duration=max_clip_duration,
|
||||||
)
|
)
|
||||||
if not downloaded_videos:
|
if not downloaded_videos:
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
|
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||||
logger.error(
|
logger.error(
|
||||||
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.")
|
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.")
|
||||||
return
|
return
|
||||||
|
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
|
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
|
||||||
|
|
||||||
final_video_paths = []
|
final_video_paths = []
|
||||||
combined_video_paths = []
|
combined_video_paths = []
|
||||||
@@ -146,7 +146,7 @@ def start(task_id, params: VideoParams):
|
|||||||
threads=n_threads)
|
threads=n_threads)
|
||||||
|
|
||||||
_progress += 50 / params.video_count / 2
|
_progress += 50 / params.video_count / 2
|
||||||
sm.update_task(task_id, progress=_progress)
|
sm.state.update_task(task_id, progress=_progress)
|
||||||
|
|
||||||
final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")
|
final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")
|
||||||
|
|
||||||
@@ -160,7 +160,7 @@ def start(task_id, params: VideoParams):
|
|||||||
)
|
)
|
||||||
|
|
||||||
_progress += 50 / params.video_count / 2
|
_progress += 50 / params.video_count / 2
|
||||||
sm.update_task(task_id, progress=_progress)
|
sm.state.update_task(task_id, progress=_progress)
|
||||||
|
|
||||||
final_video_paths.append(final_video_path)
|
final_video_paths.append(final_video_path)
|
||||||
combined_video_paths.append(combined_video_path)
|
combined_video_paths.append(combined_video_path)
|
||||||
@@ -171,5 +171,5 @@ def start(task_id, params: VideoParams):
|
|||||||
"videos": final_video_paths,
|
"videos": final_video_paths,
|
||||||
"combined_videos": combined_video_paths
|
"combined_videos": combined_video_paths
|
||||||
}
|
}
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
|
sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|||||||
@@ -129,6 +129,11 @@
|
|||||||
|
|
||||||
material_directory = ""
|
material_directory = ""
|
||||||
|
|
||||||
|
# Used for state management of the task
|
||||||
|
enable_redis = false
|
||||||
|
redis_host = "localhost"
|
||||||
|
redis_port = 6379
|
||||||
|
redis_db = 0
|
||||||
|
|
||||||
[whisper]
|
[whisper]
|
||||||
# Only effective when subtitle_provider is "whisper"
|
# Only effective when subtitle_provider is "whisper"
|
||||||
|
|||||||
@@ -15,4 +15,5 @@ pydantic~=2.6.3
|
|||||||
g4f~=0.2.5.4
|
g4f~=0.2.5.4
|
||||||
dashscope~=1.15.0
|
dashscope~=1.15.0
|
||||||
google.generativeai~=0.4.1
|
google.generativeai~=0.4.1
|
||||||
python-multipart~=0.0.9
|
python-multipart~=0.0.9
|
||||||
|
redis==5.0.3
|
||||||
Reference in New Issue
Block a user