complete task query interface

This commit is contained in:
harry
2024-04-01 20:12:14 +08:00
parent 95bc24453f
commit 9283787681
10 changed files with 139 additions and 23 deletions

View File

@@ -46,6 +46,10 @@ def get_application() -> FastAPI:
app = get_application()
task_dir = utils.task_dir()
app.mount("/tasks", StaticFiles(directory=task_dir, html=True, follow_symlink=True), name="")
public_dir = utils.public_dir()
app.mount("/", StaticFiles(directory=public_dir, html=True), name="")

View File

@@ -9,6 +9,7 @@ if not os.path.isfile(config_file):
example_file = f"{root_dir}/config.example.toml"
if os.path.isfile(example_file):
import shutil
shutil.copyfile(example_file, config_file)
logger.info(f"copy config.example.toml to config.toml")
@@ -27,8 +28,9 @@ log_level = _cfg.get("log_level", "DEBUG")
listen_host = _cfg.get("listen_host", "0.0.0.0")
listen_port = _cfg.get("listen_port", 8080)
project_name = _cfg.get("project_name", "MoneyPrinterTurbo")
project_description = _cfg.get("project_description", "MoneyPrinterTurbo\n by 抖音-网旭哈瑞.AI")
project_version = _cfg.get("project_version", "1.0.0")
project_description = _cfg.get("project_description",
"<a href='https://github.com/harry0703/MoneyPrinterTurbo'>https://github.com/harry0703/MoneyPrinterTurbo</a>")
project_version = _cfg.get("project_version", "1.0.1")
reload_debug = False
imagemagick_path = app.get("imagemagick_path", "")

View File

@@ -1,13 +1,13 @@
from os import path
from fastapi import Request, Depends, Path
from fastapi import Request, Depends, Path, BackgroundTasks
from loguru import logger
from app.config import config
from app.controllers import base
from app.controllers.v1.base import new_router
from app.models.exception import HttpException
from app.models.schema import TaskVideoRequest, TaskQueryResponse, TaskResponse, TaskQueryRequest
from app.services import task as tm
from app.services import state as sm
from app.utils import utils
# 认证依赖项
@@ -15,30 +15,43 @@ from app.utils import utils
router = new_router()
@router.post("/videos", response_model=TaskResponse, summary="使用主题来生成短视频")
def create_video(request: Request, body: TaskVideoRequest):
@router.post("/videos", response_model=TaskResponse, summary="Generate a short video")
def create_video(background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest):
task_id = utils.get_uuid()
request_id = base.get_task_id(request)
try:
task = {
"task_id": task_id,
"request_id": request_id,
"params": body.dict(),
}
body_dict = body.dict()
task.update(body_dict)
result = tm.start(task_id=task_id, params=body)
task["result"] = result
sm.update_task(task_id)
background_tasks.add_task(tm.start, task_id=task_id, params=body)
logger.success(f"video created: {utils.to_json(task)}")
return utils.get_response(200, task)
except ValueError as e:
raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}")
@router.get("/tasks/{task_id}", response_model=TaskQueryResponse, summary="查询任务状态")
def get_task(request: Request, task_id: str = Path(..., description="任务ID"),
query: TaskQueryRequest = Depends()):
@router.get("/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status")
def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
query: TaskQueryRequest = Depends()):
endpoint = config.app.get("endpoint", "")
if not endpoint:
endpoint = str(request.base_url)
endpoint = endpoint.rstrip("/")
request_id = base.get_task_id(request)
data = query.dict()
data["task_id"] = task_id
raise HttpException(task_id=task_id, status_code=404,
message=f"{request_id}: task not found", data=data)
task = sm.get_task(task_id)
if task:
if "videos" in task:
videos = task["videos"]
task_dir = utils.task_dir()
urls = []
for v in videos:
uri_path = v.replace(task_dir, "tasks")
urls.append(f"{endpoint}/{uri_path}")
task["videos"] = urls
return utils.get_response(200, task)
raise HttpException(task_id=task_id, status_code=404, message=f"{request_id}: task not found")

View File

@@ -1,4 +1,8 @@
punctuations = [
PUNCTUATIONS = [
"?", ",", ".", "", ";", ":",
"", "", "", "", "", "",
]
TASK_STATE_FAILED = -1
TASK_STATE_COMPLETE = 1
TASK_STATE_PROCESSING = 4

View File

@@ -136,7 +136,6 @@ class TaskQueryRequest(BaseModel):
class TaskResponse(BaseResponse):
class TaskResponseData(BaseModel):
task_id: str
task_type: str = ""
data: TaskResponseData

35
app/services/state.py Normal file
View File

@@ -0,0 +1,35 @@
# State Management
# This module is responsible for managing the state of the application.
import math
# 如果你部署在分布式环境中,你可能需要一个中心化的状态管理服务,比如 Redis 或者数据库。
# 如果你的应用程序是单机的,你可以使用内存来存储状态。
# If you are deploying in a distributed environment, you might need a centralized state management service like Redis or a database.
# If your application is single-node, you can use memory to store the state.
from app.models import const
from app.utils import utils
_tasks = {}
def update_task(task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
"""
Set the state of the task.
"""
progress = int(progress)
if progress > 100:
progress = 100
_tasks[task_id] = {
"state": state,
"progress": progress,
**kwargs,
}
def get_task(task_id: str):
"""
Get the state of the task.
"""
return _tasks.get(task_id, None)

View File

@@ -6,8 +6,10 @@ from os import path
from loguru import logger
from app.config import config
from app.models import const
from app.models.schema import VideoParams, VideoConcatMode
from app.services import llm, material, voice, video, subtitle
from app.services import state as sm
from app.utils import utils
@@ -26,6 +28,8 @@ def start(task_id, params: VideoParams):
}
"""
logger.info(f"start task: {task_id}")
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
video_subject = params.video_subject
voice_name = voice.parse_voice_name(params.voice_name)
paragraph_number = params.paragraph_number
@@ -40,6 +44,8 @@ def start(task_id, params: VideoParams):
else:
logger.debug(f"video script: \n{video_script}")
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
logger.info("\n\n## generating video terms")
video_terms = params.video_terms
if not video_terms:
@@ -63,10 +69,13 @@ def start(task_id, params: VideoParams):
with open(script_file, "w", encoding="utf-8") as f:
f.write(utils.to_json(script_data))
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
logger.info("\n\n## generating audio")
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3")
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file)
if sub_maker is None:
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error(
"failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.")
return
@@ -74,6 +83,8 @@ def start(task_id, params: VideoParams):
audio_duration = voice.get_audio_duration(sub_maker)
audio_duration = math.ceil(audio_duration)
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
subtitle_path = ""
if params.subtitle_enabled:
subtitle_path = path.join(utils.task_dir(task_id), f"subtitle.srt")
@@ -101,6 +112,8 @@ def start(task_id, params: VideoParams):
logger.warning(f"subtitle file is invalid: {subtitle_path}")
subtitle_path = ""
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
logger.info("\n\n## downloading videos")
downloaded_videos = material.download_videos(task_id=task_id,
search_terms=video_terms,
@@ -110,15 +123,19 @@ def start(task_id, params: VideoParams):
max_clip_duration=max_clip_duration,
)
if not downloaded_videos:
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error(
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.")
return
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
final_video_paths = []
video_concat_mode = params.video_concat_mode
if params.video_count > 1:
video_concat_mode = VideoConcatMode.random
_progress = 50
for i in range(params.video_count):
index = i + 1
combined_video_path = path.join(utils.task_dir(task_id), f"combined-{index}.mp4")
@@ -131,6 +148,9 @@ def start(task_id, params: VideoParams):
max_clip_duration=max_clip_duration,
threads=n_threads)
_progress += 50 / params.video_count / 2
sm.update_task(task_id, progress=_progress)
final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")
logger.info(f"\n\n## generating video: {index} => {final_video_path}")
@@ -141,10 +161,16 @@ def start(task_id, params: VideoParams):
output_file=final_video_path,
params=params,
)
_progress += 50 / params.video_count / 2
sm.update_task(task_id, progress=_progress)
final_video_paths.append(final_video_path)
logger.success(f"task {task_id} finished, generated {len(final_video_paths)} videos.")
return {
kwargs = {
"videos": final_video_paths,
}
sm.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
return kwargs

View File

@@ -149,7 +149,7 @@ def text_to_srt(idx: int, msg: str, start_time: float, end_time: float) -> str:
def str_contains_punctuation(word):
for p in const.punctuations:
for p in const.PUNCTUATIONS:
if p in word:
return True
return False
@@ -159,7 +159,7 @@ def split_string_by_punctuations(s):
result = []
txt = ""
for char in s:
if char not in const.punctuations:
if char not in const.PUNCTUATIONS:
txt += char
else:
result.append(txt.strip())