1.Refactor task.py to encapsulate separable functions.

2.Add a new subtitle API.
This commit is contained in:
yyhhyyyyyy
2024-07-23 17:00:23 +08:00
parent 6d520a4266
commit 17b4a61e64
3 changed files with 383 additions and 194 deletions

View File

@@ -1,11 +1,12 @@
import os
import glob import glob
import os
import pathlib import pathlib
import shutil import shutil
from typing import Union
from fastapi import Request, Depends, Path, BackgroundTasks, UploadFile from fastapi import BackgroundTasks, Depends, Path, Request, UploadFile
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.params import File from fastapi.params import File
from fastapi.responses import FileResponse, StreamingResponse
from loguru import logger from loguru import logger
from app.config import config from app.config import config
@@ -14,10 +15,19 @@ from app.controllers.manager.memory_manager import InMemoryTaskManager
from app.controllers.manager.redis_manager import RedisTaskManager from app.controllers.manager.redis_manager import RedisTaskManager
from app.controllers.v1.base import new_router from app.controllers.v1.base import new_router
from app.models.exception import HttpException from app.models.exception import HttpException
from app.models.schema import TaskVideoRequest, TaskQueryResponse, TaskResponse, TaskQueryRequest, \ from app.models.schema import (
BgmUploadResponse, BgmRetrieveResponse, TaskDeletionResponse AudioRequest,
from app.services import task as tm BgmRetrieveResponse,
BgmUploadResponse,
SubtitleRequest,
TaskDeletionResponse,
TaskQueryRequest,
TaskQueryResponse,
TaskResponse,
TaskVideoRequest,
)
from app.services import state as sm from app.services import state as sm
from app.services import task as tm
from app.utils import utils from app.utils import utils
# 认证依赖项 # 认证依赖项
@@ -34,48 +44,65 @@ _max_concurrent_tasks = config.app.get("max_concurrent_tasks", 5)
redis_url = f"redis://:{_redis_password}@{_redis_host}:{_redis_port}/{_redis_db}" redis_url = f"redis://:{_redis_password}@{_redis_host}:{_redis_port}/{_redis_db}"
# 根据配置选择合适的任务管理器 # 根据配置选择合适的任务管理器
if _enable_redis: if _enable_redis:
task_manager = RedisTaskManager(max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url) task_manager = RedisTaskManager(
max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url
)
else: else:
task_manager = InMemoryTaskManager(max_concurrent_tasks=_max_concurrent_tasks) task_manager = InMemoryTaskManager(max_concurrent_tasks=_max_concurrent_tasks)
# @router.post("/videos-test", response_model=TaskResponse, summary="Generate a short video")
# async def create_video_test(request: Request, body: TaskVideoRequest):
# task_id = utils.get_uuid()
# request_id = base.get_task_id(request)
# try:
# task = {
# "task_id": task_id,
# "request_id": request_id,
# "params": body.dict(),
# }
# task_manager.add_task(tm.start_test, task_id=task_id, params=body)
# return utils.get_response(200, task)
# except ValueError as e:
# raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}")
@router.post("/videos", response_model=TaskResponse, summary="Generate a short video") @router.post("/videos", response_model=TaskResponse, summary="Generate a short video")
def create_video(background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest): def create_video(
background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest
):
return create_task(request, body, stop_at="video")
@router.post("/subtitle", response_model=TaskResponse, summary="Generate subtitle only")
def create_subtitle(
background_tasks: BackgroundTasks, request: Request, body: SubtitleRequest
):
return create_task(request, body, stop_at="subtitle")
@router.post("/audio", response_model=TaskResponse, summary="Generate audio only")
def create_audio(
background_tasks: BackgroundTasks, request: Request, body: AudioRequest
):
return create_task(request, body, stop_at="audio")
def create_task(
request: Request,
body: Union[TaskVideoRequest, SubtitleRequest, AudioRequest],
stop_at: str,
):
task_id = utils.get_uuid() task_id = utils.get_uuid()
request_id = base.get_task_id(request) request_id = base.get_task_id(request)
try: try:
task = { task = {
"task_id": task_id, "task_id": task_id,
"request_id": request_id, "request_id": request_id,
"params": body.dict(), "params": body.model_dump(),
} }
sm.state.update_task(task_id) sm.state.update_task(task_id)
# background_tasks.add_task(tm.start, task_id=task_id, params=body) task_manager.add_task(tm.start, task_id=task_id, params=body, stop_at=stop_at)
task_manager.add_task(tm.start, task_id=task_id, params=body) logger.success(f"Task created: {utils.to_json(task)}")
logger.success(f"video created: {utils.to_json(task)}")
return utils.get_response(200, task) return utils.get_response(200, task)
except ValueError as e: except ValueError as e:
raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}") raise HttpException(
task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}"
)
@router.get("/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status") @router.get(
def get_task(request: Request, task_id: str = Path(..., description="Task ID"), "/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status"
query: TaskQueryRequest = Depends()): )
def get_task(
request: Request,
task_id: str = Path(..., description="Task ID"),
query: TaskQueryRequest = Depends(),
):
endpoint = config.app.get("endpoint", "") endpoint = config.app.get("endpoint", "")
if not endpoint: if not endpoint:
endpoint = str(request.base_url) endpoint = str(request.base_url)
@@ -108,10 +135,16 @@ def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
task["combined_videos"] = urls task["combined_videos"] = urls
return utils.get_response(200, task) return utils.get_response(200, task)
raise HttpException(task_id=task_id, status_code=404, message=f"{request_id}: task not found") raise HttpException(
task_id=task_id, status_code=404, message=f"{request_id}: task not found"
)
@router.delete("/tasks/{task_id}", response_model=TaskDeletionResponse, summary="Delete a generated short video task") @router.delete(
"/tasks/{task_id}",
response_model=TaskDeletionResponse,
summary="Delete a generated short video task",
)
def delete_video(request: Request, task_id: str = Path(..., description="Task ID")): def delete_video(request: Request, task_id: str = Path(..., description="Task ID")):
request_id = base.get_task_id(request) request_id = base.get_task_id(request)
task = sm.state.get_task(task_id) task = sm.state.get_task(task_id)
@@ -125,32 +158,40 @@ def delete_video(request: Request, task_id: str = Path(..., description="Task ID
logger.success(f"video deleted: {utils.to_json(task)}") logger.success(f"video deleted: {utils.to_json(task)}")
return utils.get_response(200) return utils.get_response(200)
raise HttpException(task_id=task_id, status_code=404, message=f"{request_id}: task not found") raise HttpException(
task_id=task_id, status_code=404, message=f"{request_id}: task not found"
)
@router.get("/musics", response_model=BgmRetrieveResponse, summary="Retrieve local BGM files") @router.get(
"/musics", response_model=BgmRetrieveResponse, summary="Retrieve local BGM files"
)
def get_bgm_list(request: Request): def get_bgm_list(request: Request):
suffix = "*.mp3" suffix = "*.mp3"
song_dir = utils.song_dir() song_dir = utils.song_dir()
files = glob.glob(os.path.join(song_dir, suffix)) files = glob.glob(os.path.join(song_dir, suffix))
bgm_list = [] bgm_list = []
for file in files: for file in files:
bgm_list.append({ bgm_list.append(
"name": os.path.basename(file), {
"size": os.path.getsize(file), "name": os.path.basename(file),
"file": file, "size": os.path.getsize(file),
}) "file": file,
response = { }
"files": bgm_list )
} response = {"files": bgm_list}
return utils.get_response(200, response) return utils.get_response(200, response)
@router.post("/musics", response_model=BgmUploadResponse, summary="Upload the BGM file to the songs directory") @router.post(
"/musics",
response_model=BgmUploadResponse,
summary="Upload the BGM file to the songs directory",
)
def upload_bgm_file(request: Request, file: UploadFile = File(...)): def upload_bgm_file(request: Request, file: UploadFile = File(...)):
request_id = base.get_task_id(request) request_id = base.get_task_id(request)
# check file ext # check file ext
if file.filename.endswith('mp3'): if file.filename.endswith("mp3"):
song_dir = utils.song_dir() song_dir = utils.song_dir()
save_path = os.path.join(song_dir, file.filename) save_path = os.path.join(song_dir, file.filename)
# save file # save file
@@ -158,26 +199,26 @@ def upload_bgm_file(request: Request, file: UploadFile = File(...)):
# If the file already exists, it will be overwritten # If the file already exists, it will be overwritten
file.file.seek(0) file.file.seek(0)
buffer.write(file.file.read()) buffer.write(file.file.read())
response = { response = {"file": save_path}
"file": save_path
}
return utils.get_response(200, response) return utils.get_response(200, response)
raise HttpException('', status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded") raise HttpException(
"", status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded"
)
@router.get("/stream/{file_path:path}") @router.get("/stream/{file_path:path}")
async def stream_video(request: Request, file_path: str): async def stream_video(request: Request, file_path: str):
tasks_dir = utils.task_dir() tasks_dir = utils.task_dir()
video_path = os.path.join(tasks_dir, file_path) video_path = os.path.join(tasks_dir, file_path)
range_header = request.headers.get('Range') range_header = request.headers.get("Range")
video_size = os.path.getsize(video_path) video_size = os.path.getsize(video_path)
start, end = 0, video_size - 1 start, end = 0, video_size - 1
length = video_size length = video_size
if range_header: if range_header:
range_ = range_header.split('bytes=')[1] range_ = range_header.split("bytes=")[1]
start, end = [int(part) if part else None for part in range_.split('-')] start, end = [int(part) if part else None for part in range_.split("-")]
if start is None: if start is None:
start = video_size - end start = video_size - end
end = video_size - 1 end = video_size - 1
@@ -186,7 +227,7 @@ async def stream_video(request: Request, file_path: str):
length = end - start + 1 length = end - start + 1
def file_iterator(file_path, offset=0, bytes_to_read=None): def file_iterator(file_path, offset=0, bytes_to_read=None):
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
f.seek(offset, os.SEEK_SET) f.seek(offset, os.SEEK_SET)
remaining = bytes_to_read or video_size remaining = bytes_to_read or video_size
while remaining > 0: while remaining > 0:
@@ -197,10 +238,12 @@ async def stream_video(request: Request, file_path: str):
remaining -= len(data) remaining -= len(data)
yield data yield data
response = StreamingResponse(file_iterator(video_path, start, length), media_type='video/mp4') response = StreamingResponse(
response.headers['Content-Range'] = f'bytes {start}-{end}/{video_size}' file_iterator(video_path, start, length), media_type="video/mp4"
response.headers['Accept-Ranges'] = 'bytes' )
response.headers['Content-Length'] = str(length) response.headers["Content-Range"] = f"bytes {start}-{end}/{video_size}"
response.headers["Accept-Ranges"] = "bytes"
response.headers["Content-Length"] = str(length)
response.status_code = 206 # Partial Content response.status_code = 206 # Partial Content
return response return response
@@ -219,8 +262,10 @@ async def download_video(_: Request, file_path: str):
file_path = pathlib.Path(video_path) file_path = pathlib.Path(video_path)
filename = file_path.stem filename = file_path.stem
extension = file_path.suffix extension = file_path.suffix
headers = { headers = {"Content-Disposition": f"attachment; filename={filename}{extension}"}
"Content-Disposition": f"attachment; filename={filename}{extension}" return FileResponse(
} path=video_path,
return FileResponse(path=video_path, headers=headers, filename=f"{filename}{extension}", headers=headers,
media_type=f'video/{extension[1:]}') filename=f"{filename}{extension}",
media_type=f"video/{extension[1:]}",
)

View File

@@ -1,12 +1,16 @@
import warnings
from enum import Enum from enum import Enum
from typing import Any, Optional, List from typing import Any, List, Optional
import pydantic import pydantic
from pydantic import BaseModel from pydantic import BaseModel
import warnings
# 忽略 Pydantic 的特定警告 # 忽略 Pydantic 的特定警告
warnings.filterwarnings("ignore", category=UserWarning, message="Field name.*shadows an attribute in parent.*") warnings.filterwarnings(
"ignore",
category=UserWarning,
message="Field name.*shadows an attribute in parent.*",
)
class VideoConcatMode(str, Enum): class VideoConcatMode(str, Enum):
@@ -61,7 +65,6 @@ class MaterialInfo:
# # "male-zh-TW-YunJheNeural", # # "male-zh-TW-YunJheNeural",
# #
# # en-US # # en-US
#
# "female-en-US-AnaNeural", # "female-en-US-AnaNeural",
# "female-en-US-AriaNeural", # "female-en-US-AriaNeural",
# "female-en-US-AvaNeural", # "female-en-US-AvaNeural",
@@ -93,6 +96,7 @@ class VideoParams(BaseModel):
"stroke_width": 1.5 "stroke_width": 1.5
} }
""" """
video_subject: str video_subject: str
video_script: str = "" # 用于生成视频的脚本 video_script: str = "" # 用于生成视频的脚本
video_terms: Optional[str | list] = None # 用于生成视频的关键词 video_terms: Optional[str | list] = None # 用于生成视频的关键词
@@ -126,6 +130,38 @@ class VideoParams(BaseModel):
paragraph_number: Optional[int] = 1 paragraph_number: Optional[int] = 1
class SubtitleRequest(BaseModel):
video_script: str
video_language: Optional[str] = ""
voice_name: Optional[str] = "zh-CN-XiaoxiaoNeural-Female"
voice_volume: Optional[float] = 1.0
voice_rate: Optional[float] = 1.2
bgm_type: Optional[str] = "random"
bgm_file: Optional[str] = ""
bgm_volume: Optional[float] = 0.2
subtitle_position: Optional[str] = "bottom"
font_name: Optional[str] = "STHeitiMedium.ttc"
text_fore_color: Optional[str] = "#FFFFFF"
text_background_color: Optional[str] = "transparent"
font_size: int = 60
stroke_color: Optional[str] = "#000000"
stroke_width: float = 1.5
video_source: Optional[str] = "local"
subtitle_enabled: Optional[str] = "true"
class AudioRequest(BaseModel):
video_script: str
video_language: Optional[str] = ""
voice_name: Optional[str] = "zh-CN-XiaoxiaoNeural-Female"
voice_volume: Optional[float] = 1.0
voice_rate: Optional[float] = 1.2
bgm_type: Optional[str] = "random"
bgm_file: Optional[str] = ""
bgm_volume: Optional[float] = 0.2
video_source: Optional[str] = "local"
class VideoScriptParams: class VideoScriptParams:
""" """
{ {
@@ -134,6 +170,7 @@ class VideoScriptParams:
"paragraph_number": 1 "paragraph_number": 1
} }
""" """
video_subject: Optional[str] = "春天的花海" video_subject: Optional[str] = "春天的花海"
video_language: Optional[str] = "" video_language: Optional[str] = ""
paragraph_number: Optional[int] = 1 paragraph_number: Optional[int] = 1
@@ -147,14 +184,17 @@ class VideoTermsParams:
"amount": 5 "amount": 5
} }
""" """
video_subject: Optional[str] = "春天的花海" video_subject: Optional[str] = "春天的花海"
video_script: Optional[str] = "春天的花海,如诗如画般展现在眼前。万物复苏的季节里,大地披上了一袭绚丽多彩的盛装。金黄的迎春、粉嫩的樱花、洁白的梨花、艳丽的郁金香……" video_script: Optional[str] = (
"春天的花海,如诗如画般展现在眼前。万物复苏的季节里,大地披上了一袭绚丽多彩的盛装。金黄的迎春、粉嫩的樱花、洁白的梨花、艳丽的郁金香……"
)
amount: Optional[int] = 5 amount: Optional[int] = 5
class BaseResponse(BaseModel): class BaseResponse(BaseModel):
status: int = 200 status: int = 200
message: Optional[str] = 'success' message: Optional[str] = "success"
data: Any = None data: Any = None
@@ -189,9 +229,7 @@ class TaskResponse(BaseResponse):
"example": { "example": {
"status": 200, "status": 200,
"message": "success", "message": "success",
"data": { "data": {"task_id": "6c85c8cc-a77a-42b9-bc30-947815aa0558"},
"task_id": "6c85c8cc-a77a-42b9-bc30-947815aa0558"
}
}, },
} }
@@ -210,8 +248,8 @@ class TaskQueryResponse(BaseResponse):
], ],
"combined_videos": [ "combined_videos": [
"http://127.0.0.1:8080/tasks/6c85c8cc-a77a-42b9-bc30-947815aa0558/combined-1.mp4" "http://127.0.0.1:8080/tasks/6c85c8cc-a77a-42b9-bc30-947815aa0558/combined-1.mp4"
] ],
} },
}, },
} }
@@ -230,8 +268,8 @@ class TaskDeletionResponse(BaseResponse):
], ],
"combined_videos": [ "combined_videos": [
"http://127.0.0.1:8080/tasks/6c85c8cc-a77a-42b9-bc30-947815aa0558/combined-1.mp4" "http://127.0.0.1:8080/tasks/6c85c8cc-a77a-42b9-bc30-947815aa0558/combined-1.mp4"
] ],
} },
}, },
} }
@@ -244,7 +282,7 @@ class VideoScriptResponse(BaseResponse):
"message": "success", "message": "success",
"data": { "data": {
"video_script": "春天的花海,是大自然的一幅美丽画卷。在这个季节里,大地复苏,万物生长,花朵争相绽放,形成了一片五彩斑斓的花海..." "video_script": "春天的花海,是大自然的一幅美丽画卷。在这个季节里,大地复苏,万物生长,花朵争相绽放,形成了一片五彩斑斓的花海..."
} },
}, },
} }
@@ -255,9 +293,7 @@ class VideoTermsResponse(BaseResponse):
"example": { "example": {
"status": 200, "status": 200,
"message": "success", "message": "success",
"data": { "data": {"video_terms": ["sky", "tree"]},
"video_terms": ["sky", "tree"]
}
}, },
} }
@@ -273,10 +309,10 @@ class BgmRetrieveResponse(BaseResponse):
{ {
"name": "output013.mp3", "name": "output013.mp3",
"size": 1891269, "size": 1891269,
"file": "/MoneyPrinterTurbo/resource/songs/output013.mp3" "file": "/MoneyPrinterTurbo/resource/songs/output013.mp3",
} }
] ]
} },
}, },
} }
@@ -287,8 +323,6 @@ class BgmUploadResponse(BaseResponse):
"example": { "example": {
"status": 200, "status": 200,
"message": "success", "message": "success",
"data": { "data": {"file": "/MoneyPrinterTurbo/resource/songs/example.mp3"},
"file": "/MoneyPrinterTurbo/resource/songs/example.mp3"
}
}, },
} }

View File

@@ -7,58 +7,42 @@ from loguru import logger
from app.config import config from app.config import config
from app.models import const from app.models import const
from app.models.schema import VideoParams, VideoConcatMode from app.models.schema import VideoConcatMode, VideoParams
from app.services import llm, material, voice, video, subtitle from app.services import llm, material, subtitle, video, voice
from app.services import state as sm from app.services import state as sm
from app.utils import utils from app.utils import utils
def start(task_id, params: VideoParams): def generate_script(task_id, params):
"""
{
"video_subject": "",
"video_aspect": "横屏 16:9西瓜视频",
"voice_name": "女生-晓晓",
"enable_bgm": false,
"font_name": "STHeitiMedium 黑体-中",
"text_color": "#FFFFFF",
"font_size": 60,
"stroke_color": "#000000",
"stroke_width": 1.5
}
"""
logger.info(f"start task: {task_id}")
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
video_subject = params.video_subject
voice_name = voice.parse_voice_name(params.voice_name)
voice_rate = params.voice_rate
paragraph_number = params.paragraph_number
n_threads = params.n_threads
max_clip_duration = params.video_clip_duration
logger.info("\n\n## generating video script") logger.info("\n\n## generating video script")
video_script = params.video_script.strip() video_script = params.video_script.strip()
if not video_script: if not video_script:
video_script = llm.generate_script(video_subject=video_subject, language=params.video_language, video_script = llm.generate_script(
paragraph_number=paragraph_number) video_subject=params.video_subject,
language=params.video_language,
paragraph_number=params.paragraph_number,
)
else: else:
logger.debug(f"video script: \n{video_script}") logger.debug(f"video script: \n{video_script}")
if not video_script: if not video_script:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error("failed to generate video script.") logger.error("failed to generate video script.")
return return None
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10) return video_script
def generate_terms(task_id, params, video_script):
logger.info("\n\n## generating video terms") logger.info("\n\n## generating video terms")
video_terms = params.video_terms video_terms = params.video_terms
if not video_terms: if not video_terms:
video_terms = llm.generate_terms(video_subject=video_subject, video_script=video_script, amount=5) video_terms = llm.generate_terms(
video_subject=params.video_subject, video_script=video_script, amount=5
)
else: else:
if isinstance(video_terms, str): if isinstance(video_terms, str):
video_terms = [term.strip() for term in re.split(r'[,]', video_terms)] video_terms = [term.strip() for term in re.split(r"[,]", video_terms)]
elif isinstance(video_terms, list): elif isinstance(video_terms, list):
video_terms = [term.strip() for term in video_terms] video_terms = [term.strip() for term in video_terms]
else: else:
@@ -69,9 +53,13 @@ def start(task_id, params: VideoParams):
if not video_terms: if not video_terms:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error("failed to generate video terms.") logger.error("failed to generate video terms.")
return return None
script_file = path.join(utils.task_dir(task_id), f"script.json") return video_terms
def save_script_data(task_id, video_script, video_terms, params):
script_file = path.join(utils.task_dir(task_id), "script.json")
script_data = { script_data = {
"script": video_script, "script": video_script,
"search_terms": video_terms, "search_terms": video_terms,
@@ -81,11 +69,16 @@ def start(task_id, params: VideoParams):
with open(script_file, "w", encoding="utf-8") as f: with open(script_file, "w", encoding="utf-8") as f:
f.write(utils.to_json(script_data)) f.write(utils.to_json(script_data))
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
def generate_audio(task_id, params, video_script):
logger.info("\n\n## generating audio") logger.info("\n\n## generating audio")
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3") audio_file = path.join(utils.task_dir(task_id), "audio.mp3")
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_rate=voice_rate, voice_file=audio_file) sub_maker = voice.tts(
text=video_script,
voice_name=voice.parse_voice_name(params.voice_name),
voice_rate=params.voice_rate,
voice_file=audio_file,
)
if sub_maker is None: if sub_maker is None:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error( logger.error(
@@ -94,86 +87,100 @@ def start(task_id, params: VideoParams):
2. check if the network is available. If you are in China, it is recommended to use a VPN and enable the global traffic mode. 2. check if the network is available. If you are in China, it is recommended to use a VPN and enable the global traffic mode.
""".strip() """.strip()
) )
return return None, None
audio_duration = voice.get_audio_duration(sub_maker) audio_duration = math.ceil(voice.get_audio_duration(sub_maker))
audio_duration = math.ceil(audio_duration) return audio_file, audio_duration
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
subtitle_path = "" def generate_subtitle(task_id, params, video_script, sub_maker, audio_file):
if params.subtitle_enabled: if not params.subtitle_enabled:
subtitle_path = path.join(utils.task_dir(task_id), f"subtitle.srt") return ""
subtitle_provider = config.app.get("subtitle_provider", "").strip().lower()
logger.info(f"\n\n## generating subtitle, provider: {subtitle_provider}")
subtitle_fallback = False
if subtitle_provider == "edge":
voice.create_subtitle(text=video_script, sub_maker=sub_maker, subtitle_file=subtitle_path)
if not os.path.exists(subtitle_path):
subtitle_fallback = True
logger.warning("subtitle file not found, fallback to whisper")
if subtitle_provider == "whisper" or subtitle_fallback: subtitle_path = path.join(utils.task_dir(task_id), "subtitle.srt")
subtitle.create(audio_file=audio_file, subtitle_file=subtitle_path) subtitle_provider = config.app.get("subtitle_provider", "").strip().lower()
logger.info("\n\n## correcting subtitle") logger.info(f"\n\n## generating subtitle, provider: {subtitle_provider}")
subtitle.correct(subtitle_file=subtitle_path, video_script=video_script)
subtitle_lines = subtitle.file_to_subtitles(subtitle_path) subtitle_fallback = False
if not subtitle_lines: if subtitle_provider == "edge":
logger.warning(f"subtitle file is invalid: {subtitle_path}") voice.create_subtitle(
subtitle_path = "" text=video_script, sub_maker=sub_maker, subtitle_file=subtitle_path
)
if not os.path.exists(subtitle_path):
subtitle_fallback = True
logger.warning("subtitle file not found, fallback to whisper")
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40) if subtitle_provider == "whisper" or subtitle_fallback:
subtitle.create(audio_file=audio_file, subtitle_file=subtitle_path)
logger.info("\n\n## correcting subtitle")
subtitle.correct(subtitle_file=subtitle_path, video_script=video_script)
downloaded_videos = [] subtitle_lines = subtitle.file_to_subtitles(subtitle_path)
if not subtitle_lines:
logger.warning(f"subtitle file is invalid: {subtitle_path}")
return ""
return subtitle_path
def get_video_materials(task_id, params, video_terms, audio_duration):
if params.video_source == "local": if params.video_source == "local":
logger.info("\n\n## preprocess local materials") logger.info("\n\n## preprocess local materials")
materials = video.preprocess_video(materials=params.video_materials, clip_duration=max_clip_duration) materials = video.preprocess_video(
print(materials) materials=params.video_materials, clip_duration=params.video_clip_duration
)
if not materials: if not materials:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error("no valid materials found, please check the materials and try again.") logger.error(
return "no valid materials found, please check the materials and try again."
for material_info in materials: )
print(material_info) return None
downloaded_videos.append(material_info.url) return [material_info.url for material_info in materials]
else: else:
logger.info(f"\n\n## downloading videos from {params.video_source}") logger.info(f"\n\n## downloading videos from {params.video_source}")
downloaded_videos = material.download_videos(task_id=task_id, downloaded_videos = material.download_videos(
search_terms=video_terms, task_id=task_id,
source=params.video_source, search_terms=video_terms,
video_aspect=params.video_aspect, source=params.video_source,
video_contact_mode=params.video_concat_mode, video_aspect=params.video_aspect,
audio_duration=audio_duration * params.video_count, video_contact_mode=params.video_concat_mode,
max_clip_duration=max_clip_duration, audio_duration=audio_duration * params.video_count,
) max_clip_duration=params.video_clip_duration,
if not downloaded_videos: )
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) if not downloaded_videos:
logger.error( sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.") logger.error(
return "failed to download videos, maybe the network is not available. if you are in China, please use a VPN."
)
return None
return downloaded_videos
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
def generate_final_videos(
task_id, params, downloaded_videos, audio_file, subtitle_path
):
final_video_paths = [] final_video_paths = []
combined_video_paths = [] combined_video_paths = []
video_concat_mode = params.video_concat_mode video_concat_mode = (
if params.video_count > 1: params.video_concat_mode if params.video_count > 1 else VideoConcatMode.random
video_concat_mode = VideoConcatMode.random )
_progress = 50 _progress = 50
for i in range(params.video_count): for i in range(params.video_count):
index = i + 1 index = i + 1
combined_video_path = path.join(utils.task_dir(task_id), f"combined-{index}.mp4") combined_video_path = path.join(
utils.task_dir(task_id), f"combined-{index}.mp4"
)
logger.info(f"\n\n## combining video: {index} => {combined_video_path}") logger.info(f"\n\n## combining video: {index} => {combined_video_path}")
video.combine_videos(combined_video_path=combined_video_path, video.combine_videos(
video_paths=downloaded_videos, combined_video_path=combined_video_path,
audio_file=audio_file, video_paths=downloaded_videos,
video_aspect=params.video_aspect, audio_file=audio_file,
video_concat_mode=video_concat_mode, video_aspect=params.video_aspect,
max_clip_duration=max_clip_duration, video_concat_mode=video_concat_mode,
threads=n_threads) max_clip_duration=params.video_clip_duration,
threads=params.n_threads,
)
_progress += 50 / params.video_count / 2 _progress += 50 / params.video_count / 2
sm.state.update_task(task_id, progress=_progress) sm.state.update_task(task_id, progress=_progress)
@@ -181,13 +188,13 @@ def start(task_id, params: VideoParams):
final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4") final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")
logger.info(f"\n\n## generating video: {index} => {final_video_path}") logger.info(f"\n\n## generating video: {index} => {final_video_path}")
# Put everything together video.generate_video(
video.generate_video(video_path=combined_video_path, video_path=combined_video_path,
audio_path=audio_file, audio_path=audio_file,
subtitle_path=subtitle_path, subtitle_path=subtitle_path,
output_file=final_video_path, output_file=final_video_path,
params=params, params=params,
) )
_progress += 50 / params.video_count / 2 _progress += 50 / params.video_count / 2
sm.state.update_task(task_id, progress=_progress) sm.state.update_task(task_id, progress=_progress)
@@ -195,16 +202,119 @@ def start(task_id, params: VideoParams):
final_video_paths.append(final_video_path) final_video_paths.append(final_video_path)
combined_video_paths.append(combined_video_path) combined_video_paths.append(combined_video_path)
logger.success(f"task {task_id} finished, generated {len(final_video_paths)} videos.") return final_video_paths, combined_video_paths
def start(task_id, params: VideoParams, stop_at: str = "video"):
logger.info(f"start task: {task_id}, stop_at: {stop_at}")
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
# 1. Generate script
video_script = generate_script(task_id, params)
if not video_script:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
return
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
if stop_at == "script":
sm.state.update_task(
task_id, state=const.TASK_STATE_COMPLETE, progress=100, script=video_script
)
return {"script": video_script}
# 2. Generate terms
video_terms = ""
if params.video_source != "local":
video_terms = generate_terms(task_id, params, video_script)
if not video_terms:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
return
save_script_data(task_id, video_script, video_terms, params)
if stop_at == "terms":
sm.state.update_task(
task_id, state=const.TASK_STATE_COMPLETE, progress=100, terms=video_terms
)
return {"script": video_script, "terms": video_terms}
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
# 3. Generate audio
audio_file, audio_duration = generate_audio(task_id, params, video_script)
if not audio_file:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
return
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
if stop_at == "audio":
sm.state.update_task(
task_id,
state=const.TASK_STATE_COMPLETE,
progress=100,
audio_file=audio_file,
)
return {"audio_file": audio_file, "audio_duration": audio_duration}
# 4. Generate subtitle
subtitle_path = generate_subtitle(task_id, params, video_script, None, audio_file)
if stop_at == "subtitle":
sm.state.update_task(
task_id,
state=const.TASK_STATE_COMPLETE,
progress=100,
subtitle_path=subtitle_path,
)
return {"subtitle_path": subtitle_path}
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
# 5. Get video materials
downloaded_videos = get_video_materials(
task_id, params, video_terms, audio_duration
)
if not downloaded_videos:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
return
if stop_at == "materials":
sm.state.update_task(
task_id,
state=const.TASK_STATE_COMPLETE,
progress=100,
materials=downloaded_videos,
)
return {"materials": downloaded_videos}
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
# 6. Generate final videos
final_video_paths, combined_video_paths = generate_final_videos(
task_id, params, downloaded_videos, audio_file, subtitle_path
)
if not final_video_paths:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
return
logger.success(
f"task {task_id} finished, generated {len(final_video_paths)} videos."
)
kwargs = { kwargs = {
"videos": final_video_paths, "videos": final_video_paths,
"combined_videos": combined_video_paths "combined_videos": combined_video_paths,
"script": video_script,
"terms": video_terms,
"audio_file": audio_file,
"audio_duration": audio_duration,
"subtitle_path": subtitle_path,
"materials": downloaded_videos,
} }
sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs) sm.state.update_task(
task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs
)
return kwargs return kwargs
# def start_test(task_id, params: VideoParams):
# print(f"start task {task_id} \n")
# time.sleep(5)
# print(f"task {task_id} finished \n")