complete task query interface

This commit is contained in:
harry
2024-04-01 20:12:14 +08:00
parent 95bc24453f
commit 9283787681
10 changed files with 139 additions and 23 deletions

View File

@@ -46,6 +46,10 @@ def get_application() -> FastAPI:
app = get_application() app = get_application()
task_dir = utils.task_dir()
app.mount("/tasks", StaticFiles(directory=task_dir, html=True, follow_symlink=True), name="")
public_dir = utils.public_dir() public_dir = utils.public_dir()
app.mount("/", StaticFiles(directory=public_dir, html=True), name="") app.mount("/", StaticFiles(directory=public_dir, html=True), name="")

View File

@@ -9,6 +9,7 @@ if not os.path.isfile(config_file):
example_file = f"{root_dir}/config.example.toml" example_file = f"{root_dir}/config.example.toml"
if os.path.isfile(example_file): if os.path.isfile(example_file):
import shutil import shutil
shutil.copyfile(example_file, config_file) shutil.copyfile(example_file, config_file)
logger.info(f"copy config.example.toml to config.toml") logger.info(f"copy config.example.toml to config.toml")
@@ -27,8 +28,9 @@ log_level = _cfg.get("log_level", "DEBUG")
listen_host = _cfg.get("listen_host", "0.0.0.0") listen_host = _cfg.get("listen_host", "0.0.0.0")
listen_port = _cfg.get("listen_port", 8080) listen_port = _cfg.get("listen_port", 8080)
project_name = _cfg.get("project_name", "MoneyPrinterTurbo") project_name = _cfg.get("project_name", "MoneyPrinterTurbo")
project_description = _cfg.get("project_description", "MoneyPrinterTurbo\n by 抖音-网旭哈瑞.AI") project_description = _cfg.get("project_description",
project_version = _cfg.get("project_version", "1.0.0") "<a href='https://github.com/harry0703/MoneyPrinterTurbo'>https://github.com/harry0703/MoneyPrinterTurbo</a>")
project_version = _cfg.get("project_version", "1.0.1")
reload_debug = False reload_debug = False
imagemagick_path = app.get("imagemagick_path", "") imagemagick_path = app.get("imagemagick_path", "")

View File

@@ -1,13 +1,13 @@
from os import path from fastapi import Request, Depends, Path, BackgroundTasks
from fastapi import Request, Depends, Path
from loguru import logger from loguru import logger
from app.config import config
from app.controllers import base from app.controllers import base
from app.controllers.v1.base import new_router from app.controllers.v1.base import new_router
from app.models.exception import HttpException from app.models.exception import HttpException
from app.models.schema import TaskVideoRequest, TaskQueryResponse, TaskResponse, TaskQueryRequest from app.models.schema import TaskVideoRequest, TaskQueryResponse, TaskResponse, TaskQueryRequest
from app.services import task as tm from app.services import task as tm
from app.services import state as sm
from app.utils import utils from app.utils import utils
# 认证依赖项 # 认证依赖项
@@ -15,30 +15,43 @@ from app.utils import utils
router = new_router() router = new_router()
@router.post("/videos", response_model=TaskResponse, summary="使用主题来生成短视频") @router.post("/videos", response_model=TaskResponse, summary="Generate a short video")
def create_video(request: Request, body: TaskVideoRequest): def create_video(background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest):
task_id = utils.get_uuid() task_id = utils.get_uuid()
request_id = base.get_task_id(request) request_id = base.get_task_id(request)
try: try:
task = { task = {
"task_id": task_id, "task_id": task_id,
"request_id": request_id, "request_id": request_id,
"params": body.dict(),
} }
body_dict = body.dict() sm.update_task(task_id)
task.update(body_dict) background_tasks.add_task(tm.start, task_id=task_id, params=body)
result = tm.start(task_id=task_id, params=body)
task["result"] = result
logger.success(f"video created: {utils.to_json(task)}") logger.success(f"video created: {utils.to_json(task)}")
return utils.get_response(200, task) return utils.get_response(200, task)
except ValueError as e: except ValueError as e:
raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}") raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}")
@router.get("/tasks/{task_id}", response_model=TaskQueryResponse, summary="查询任务状态") @router.get("/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status")
def get_task(request: Request, task_id: str = Path(..., description="任务ID"), def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
query: TaskQueryRequest = Depends()): query: TaskQueryRequest = Depends()):
endpoint = config.app.get("endpoint", "")
if not endpoint:
endpoint = str(request.base_url)
endpoint = endpoint.rstrip("/")
request_id = base.get_task_id(request) request_id = base.get_task_id(request)
data = query.dict() task = sm.get_task(task_id)
data["task_id"] = task_id if task:
raise HttpException(task_id=task_id, status_code=404, if "videos" in task:
message=f"{request_id}: task not found", data=data) videos = task["videos"]
task_dir = utils.task_dir()
urls = []
for v in videos:
uri_path = v.replace(task_dir, "tasks")
urls.append(f"{endpoint}/{uri_path}")
task["videos"] = urls
return utils.get_response(200, task)
raise HttpException(task_id=task_id, status_code=404, message=f"{request_id}: task not found")

View File

@@ -1,4 +1,8 @@
punctuations = [ PUNCTUATIONS = [
"?", ",", ".", "", ";", ":", "?", ",", ".", "", ";", ":",
"", "", "", "", "", "", "", "", "", "", "", "",
] ]
TASK_STATE_FAILED = -1
TASK_STATE_COMPLETE = 1
TASK_STATE_PROCESSING = 4

View File

@@ -136,7 +136,6 @@ class TaskQueryRequest(BaseModel):
class TaskResponse(BaseResponse): class TaskResponse(BaseResponse):
class TaskResponseData(BaseModel): class TaskResponseData(BaseModel):
task_id: str task_id: str
task_type: str = ""
data: TaskResponseData data: TaskResponseData

35
app/services/state.py Normal file
View File

@@ -0,0 +1,35 @@
# State Management
# This module is responsible for managing the state of the application.
import math
# 如果你部署在分布式环境中,你可能需要一个中心化的状态管理服务,比如 Redis 或者数据库。
# 如果你的应用程序是单机的,你可以使用内存来存储状态。
# If you are deploying in a distributed environment, you might need a centralized state management service like Redis or a database.
# If your application is single-node, you can use memory to store the state.
from app.models import const
from app.utils import utils
_tasks = {}
def update_task(task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
"""
Set the state of the task.
"""
progress = int(progress)
if progress > 100:
progress = 100
_tasks[task_id] = {
"state": state,
"progress": progress,
**kwargs,
}
def get_task(task_id: str):
"""
Get the state of the task.
"""
return _tasks.get(task_id, None)

View File

@@ -6,8 +6,10 @@ from os import path
from loguru import logger from loguru import logger
from app.config import config from app.config import config
from app.models import const
from app.models.schema import VideoParams, VideoConcatMode from app.models.schema import VideoParams, VideoConcatMode
from app.services import llm, material, voice, video, subtitle from app.services import llm, material, voice, video, subtitle
from app.services import state as sm
from app.utils import utils from app.utils import utils
@@ -26,6 +28,8 @@ def start(task_id, params: VideoParams):
} }
""" """
logger.info(f"start task: {task_id}") logger.info(f"start task: {task_id}")
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
video_subject = params.video_subject video_subject = params.video_subject
voice_name = voice.parse_voice_name(params.voice_name) voice_name = voice.parse_voice_name(params.voice_name)
paragraph_number = params.paragraph_number paragraph_number = params.paragraph_number
@@ -40,6 +44,8 @@ def start(task_id, params: VideoParams):
else: else:
logger.debug(f"video script: \n{video_script}") logger.debug(f"video script: \n{video_script}")
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
logger.info("\n\n## generating video terms") logger.info("\n\n## generating video terms")
video_terms = params.video_terms video_terms = params.video_terms
if not video_terms: if not video_terms:
@@ -63,10 +69,13 @@ def start(task_id, params: VideoParams):
with open(script_file, "w", encoding="utf-8") as f: with open(script_file, "w", encoding="utf-8") as f:
f.write(utils.to_json(script_data)) f.write(utils.to_json(script_data))
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
logger.info("\n\n## generating audio") logger.info("\n\n## generating audio")
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3") audio_file = path.join(utils.task_dir(task_id), f"audio.mp3")
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file) sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file)
if sub_maker is None: if sub_maker is None:
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error( logger.error(
"failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.") "failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.")
return return
@@ -74,6 +83,8 @@ def start(task_id, params: VideoParams):
audio_duration = voice.get_audio_duration(sub_maker) audio_duration = voice.get_audio_duration(sub_maker)
audio_duration = math.ceil(audio_duration) audio_duration = math.ceil(audio_duration)
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
subtitle_path = "" subtitle_path = ""
if params.subtitle_enabled: if params.subtitle_enabled:
subtitle_path = path.join(utils.task_dir(task_id), f"subtitle.srt") subtitle_path = path.join(utils.task_dir(task_id), f"subtitle.srt")
@@ -101,6 +112,8 @@ def start(task_id, params: VideoParams):
logger.warning(f"subtitle file is invalid: {subtitle_path}") logger.warning(f"subtitle file is invalid: {subtitle_path}")
subtitle_path = "" subtitle_path = ""
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
logger.info("\n\n## downloading videos") logger.info("\n\n## downloading videos")
downloaded_videos = material.download_videos(task_id=task_id, downloaded_videos = material.download_videos(task_id=task_id,
search_terms=video_terms, search_terms=video_terms,
@@ -110,15 +123,19 @@ def start(task_id, params: VideoParams):
max_clip_duration=max_clip_duration, max_clip_duration=max_clip_duration,
) )
if not downloaded_videos: if not downloaded_videos:
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error( logger.error(
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.") "failed to download videos, maybe the network is not available. if you are in China, please use a VPN.")
return return
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
final_video_paths = [] final_video_paths = []
video_concat_mode = params.video_concat_mode video_concat_mode = params.video_concat_mode
if params.video_count > 1: if params.video_count > 1:
video_concat_mode = VideoConcatMode.random video_concat_mode = VideoConcatMode.random
_progress = 50
for i in range(params.video_count): for i in range(params.video_count):
index = i + 1 index = i + 1
combined_video_path = path.join(utils.task_dir(task_id), f"combined-{index}.mp4") combined_video_path = path.join(utils.task_dir(task_id), f"combined-{index}.mp4")
@@ -131,6 +148,9 @@ def start(task_id, params: VideoParams):
max_clip_duration=max_clip_duration, max_clip_duration=max_clip_duration,
threads=n_threads) threads=n_threads)
_progress += 50 / params.video_count / 2
sm.update_task(task_id, progress=_progress)
final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4") final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")
logger.info(f"\n\n## generating video: {index} => {final_video_path}") logger.info(f"\n\n## generating video: {index} => {final_video_path}")
@@ -141,10 +161,16 @@ def start(task_id, params: VideoParams):
output_file=final_video_path, output_file=final_video_path,
params=params, params=params,
) )
_progress += 50 / params.video_count / 2
sm.update_task(task_id, progress=_progress)
final_video_paths.append(final_video_path) final_video_paths.append(final_video_path)
logger.success(f"task {task_id} finished, generated {len(final_video_paths)} videos.") logger.success(f"task {task_id} finished, generated {len(final_video_paths)} videos.")
return { kwargs = {
"videos": final_video_paths, "videos": final_video_paths,
} }
sm.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
return kwargs

View File

@@ -149,7 +149,7 @@ def text_to_srt(idx: int, msg: str, start_time: float, end_time: float) -> str:
def str_contains_punctuation(word): def str_contains_punctuation(word):
for p in const.punctuations: for p in const.PUNCTUATIONS:
if p in word: if p in word:
return True return True
return False return False
@@ -159,7 +159,7 @@ def split_string_by_punctuations(s):
result = [] result = []
txt = "" txt = ""
for char in s: for char in s:
if char not in const.punctuations: if char not in const.PUNCTUATIONS:
txt += char txt += char
else: else:
result.append(txt.strip()) result.append(txt.strip())

View File

@@ -97,6 +97,20 @@
# ffmpeg_path = "C:\\Users\\harry\\Downloads\\ffmpeg.exe" # ffmpeg_path = "C:\\Users\\harry\\Downloads\\ffmpeg.exe"
######################################################################################### #########################################################################################
# 当视频生成成功后API服务提供的视频下载接入点默认为当前服务的地址和监听端口
# 比如 http://127.0.0.1:8080/tasks/6357f542-a4e1-46a1-b4c9-bf3bd0df5285/final-1.mp4
# 如果你需要使用域名对外提供服务一般会用nginx做代理则可以设置为你的域名
# 比如 https://xxxx.com/tasks/6357f542-a4e1-46a1-b4c9-bf3bd0df5285/final-1.mp4
# endpoint="https://xxxx.com"
# When the video is successfully generated, the API service provides a download endpoint for the video, defaulting to the service's current address and listening port.
# For example, http://127.0.0.1:8080/tasks/6357f542-a4e1-46a1-b4c9-bf3bd0df5285/final-1.mp4
# If you need to provide the service externally using a domain name (usually done with nginx as a proxy), you can set it to your domain name.
# For example, https://xxxx.com/tasks/6357f542-a4e1-46a1-b4c9-bf3bd0df5285/final-1.mp4
# endpoint="https://xxxx.com"
endpoint=""
[whisper] [whisper]
# Only effective when subtitle_provider is "whisper" # Only effective when subtitle_provider is "whisper"

View File

@@ -0,0 +1,19 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>MoneyPrinterTurbo</title>
</head>
<body>
<h1>MoneyPrinterTurbo</h1>
<a href="https://github.com/harry0703/MoneyPrinterTurbo">https://github.com/harry0703/MoneyPrinterTurbo</a>
<p>
只需提供一个视频 主题 或 关键词 ,就可以全自动生成视频文案、视频素材、视频字幕、视频背景音乐,然后合成一个高清的短视频。
</p>
<p>
Simply provide a topic or keyword for a video, and it will automatically generate the video copy, video materials,
video subtitles, and video background music before synthesizing a high-definition short video.
</p>
</body>
</html>