Format project code

This commit is contained in:
yyhhyyyyyy
2024-07-24 14:59:06 +08:00
parent bbd4e94941
commit 905841965a
18 changed files with 410 additions and 214 deletions

View File

@@ -1,4 +1,5 @@
"""Application implementation - ASGI.""" """Application implementation - ASGI."""
import os import os
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
@@ -24,7 +25,9 @@ def exception_handler(request: Request, e: HttpException):
def validation_exception_handler(request: Request, e: RequestValidationError): def validation_exception_handler(request: Request, e: RequestValidationError):
return JSONResponse( return JSONResponse(
status_code=400, status_code=400,
content=utils.get_response(status=400, data=e.errors(), message='field required'), content=utils.get_response(
status=400, data=e.errors(), message="field required"
),
) )
@@ -61,7 +64,9 @@ app.add_middleware(
) )
task_dir = utils.task_dir() task_dir = utils.task_dir()
app.mount("/tasks", StaticFiles(directory=task_dir, html=True, follow_symlink=True), name="") app.mount(
"/tasks", StaticFiles(directory=task_dir, html=True, follow_symlink=True), name=""
)
public_dir = utils.public_dir() public_dir = utils.public_dir()
app.mount("/", StaticFiles(directory=public_dir, html=True), name="") app.mount("/", StaticFiles(directory=public_dir, html=True), name="")

View File

@@ -10,7 +10,9 @@ from app.utils import utils
def __init_logger(): def __init_logger():
# _log_file = utils.storage_dir("logs/server.log") # _log_file = utils.storage_dir("logs/server.log")
_lvl = config.log_level _lvl = config.log_level
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) root_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
)
def format_record(record): def format_record(record):
# 获取日志记录中的文件全路径 # 获取日志记录中的文件全路径
@@ -21,10 +23,13 @@ def __init_logger():
record["file"].path = f"./{relative_path}" record["file"].path = f"./{relative_path}"
# 返回修改后的格式字符串 # 返回修改后的格式字符串
# 您可以根据需要调整这里的格式 # 您可以根据需要调整这里的格式
_format = '<green>{time:%Y-%m-%d %H:%M:%S}</> | ' + \ _format = (
'<level>{level}</> | ' + \ "<green>{time:%Y-%m-%d %H:%M:%S}</> | "
'"{file.path}:{line}":<blue> {function}</> ' + \ + "<level>{level}</> | "
'- <level>{message}</>' + "\n" + '"{file.path}:{line}":<blue> {function}</> '
+ "- <level>{message}</>"
+ "\n"
)
return _format return _format
logger.remove() logger.remove()

View File

@@ -25,7 +25,7 @@ def load_config():
_config_ = toml.load(config_file) _config_ = toml.load(config_file)
except Exception as e: except Exception as e:
logger.warning(f"load config failed: {str(e)}, try to load as utf-8-sig") logger.warning(f"load config failed: {str(e)}, try to load as utf-8-sig")
with open(config_file, mode="r", encoding='utf-8-sig') as fp: with open(config_file, mode="r", encoding="utf-8-sig") as fp:
_cfg_content = fp.read() _cfg_content = fp.read()
_config_ = toml.loads(_cfg_content) _config_ = toml.loads(_cfg_content)
return _config_ return _config_
@@ -52,8 +52,10 @@ log_level = _cfg.get("log_level", "DEBUG")
listen_host = _cfg.get("listen_host", "0.0.0.0") listen_host = _cfg.get("listen_host", "0.0.0.0")
listen_port = _cfg.get("listen_port", 8080) listen_port = _cfg.get("listen_port", 8080)
project_name = _cfg.get("project_name", "MoneyPrinterTurbo") project_name = _cfg.get("project_name", "MoneyPrinterTurbo")
project_description = _cfg.get("project_description", project_description = _cfg.get(
"<a href='https://github.com/harry0703/MoneyPrinterTurbo'>https://github.com/harry0703/MoneyPrinterTurbo</a>") "project_description",
"<a href='https://github.com/harry0703/MoneyPrinterTurbo'>https://github.com/harry0703/MoneyPrinterTurbo</a>",
)
project_version = _cfg.get("project_version", "1.1.9") project_version = _cfg.get("project_version", "1.1.9")
reload_debug = False reload_debug = False

View File

@@ -7,14 +7,14 @@ from app.models.exception import HttpException
def get_task_id(request: Request): def get_task_id(request: Request):
task_id = request.headers.get('x-task-id') task_id = request.headers.get("x-task-id")
if not task_id: if not task_id:
task_id = uuid4() task_id = uuid4()
return str(task_id) return str(task_id)
def get_api_key(request: Request): def get_api_key(request: Request):
api_key = request.headers.get('x-api-key') api_key = request.headers.get("x-api-key")
return api_key return api_key
@@ -23,5 +23,9 @@ def verify_token(request: Request):
if token != config.app.get("api_key", ""): if token != config.app.get("api_key", ""):
request_id = get_task_id(request) request_id = get_task_id(request)
request_url = request.url request_url = request.url
user_agent = request.headers.get('user-agent') user_agent = request.headers.get("user-agent")
raise HttpException(task_id=request_id, status_code=401, message=f"invalid token: {request_url}, {user_agent}") raise HttpException(
task_id=request_id,
status_code=401,
message=f"invalid token: {request_url}, {user_agent}",
)

View File

@@ -18,11 +18,15 @@ class TaskManager:
print(f"add task: {func.__name__}, current_tasks: {self.current_tasks}") print(f"add task: {func.__name__}, current_tasks: {self.current_tasks}")
self.execute_task(func, *args, **kwargs) self.execute_task(func, *args, **kwargs)
else: else:
print(f"enqueue task: {func.__name__}, current_tasks: {self.current_tasks}") print(
f"enqueue task: {func.__name__}, current_tasks: {self.current_tasks}"
)
self.enqueue({"func": func, "args": args, "kwargs": kwargs}) self.enqueue({"func": func, "args": args, "kwargs": kwargs})
def execute_task(self, func: Callable, *args: Any, **kwargs: Any): def execute_task(self, func: Callable, *args: Any, **kwargs: Any):
thread = threading.Thread(target=self.run_task, args=(func, *args), kwargs=kwargs) thread = threading.Thread(
target=self.run_task, args=(func, *args), kwargs=kwargs
)
thread.start() thread.start()
def run_task(self, func: Callable, *args: Any, **kwargs: Any): def run_task(self, func: Callable, *args: Any, **kwargs: Any):
@@ -35,11 +39,14 @@ class TaskManager:
def check_queue(self): def check_queue(self):
with self.lock: with self.lock:
if self.current_tasks < self.max_concurrent_tasks and not self.is_queue_empty(): if (
self.current_tasks < self.max_concurrent_tasks
and not self.is_queue_empty()
):
task_info = self.dequeue() task_info = self.dequeue()
func = task_info['func'] func = task_info["func"]
args = task_info.get('args', ()) args = task_info.get("args", ())
kwargs = task_info.get('kwargs', {}) kwargs = task_info.get("kwargs", {})
self.execute_task(func, *args, **kwargs) self.execute_task(func, *args, **kwargs)
def task_done(self): def task_done(self):

View File

@@ -8,7 +8,7 @@ from app.models.schema import VideoParams
from app.services import task as tm from app.services import task as tm
FUNC_MAP = { FUNC_MAP = {
'start': tm.start, "start": tm.start,
# 'start_test': tm.start_test # 'start_test': tm.start_test
} }
@@ -24,11 +24,15 @@ class RedisTaskManager(TaskManager):
def enqueue(self, task: Dict): def enqueue(self, task: Dict):
task_with_serializable_params = task.copy() task_with_serializable_params = task.copy()
if 'params' in task['kwargs'] and isinstance(task['kwargs']['params'], VideoParams): if "params" in task["kwargs"] and isinstance(
task_with_serializable_params['kwargs']['params'] = task['kwargs']['params'].dict() task["kwargs"]["params"], VideoParams
):
task_with_serializable_params["kwargs"]["params"] = task["kwargs"][
"params"
].dict()
# 将函数对象转换为其名称 # 将函数对象转换为其名称
task_with_serializable_params['func'] = task['func'].__name__ task_with_serializable_params["func"] = task["func"].__name__
self.redis_client.rpush(self.queue, json.dumps(task_with_serializable_params)) self.redis_client.rpush(self.queue, json.dumps(task_with_serializable_params))
def dequeue(self): def dequeue(self):
@@ -36,10 +40,14 @@ class RedisTaskManager(TaskManager):
if task_json: if task_json:
task_info = json.loads(task_json) task_info = json.loads(task_json)
# 将函数名称转换回函数对象 # 将函数名称转换回函数对象
task_info['func'] = FUNC_MAP[task_info['func']] task_info["func"] = FUNC_MAP[task_info["func"]]
if 'params' in task_info['kwargs'] and isinstance(task_info['kwargs']['params'], dict): if "params" in task_info["kwargs"] and isinstance(
task_info['kwargs']['params'] = VideoParams(**task_info['kwargs']['params']) task_info["kwargs"]["params"], dict
):
task_info["kwargs"]["params"] = VideoParams(
**task_info["kwargs"]["params"]
)
return task_info return task_info
return None return None

View File

@@ -4,6 +4,11 @@ from fastapi import Request
router = APIRouter() router = APIRouter()
@router.get("/ping", tags=["Health Check"], description="检查服务可用性", response_description="pong") @router.get(
"/ping",
tags=["Health Check"],
description="检查服务可用性",
response_description="pong",
)
def ping(request: Request) -> str: def ping(request: Request) -> str:
return "pong" return "pong"

View File

@@ -3,8 +3,8 @@ from fastapi import APIRouter, Depends
def new_router(dependencies=None): def new_router(dependencies=None):
router = APIRouter() router = APIRouter()
router.tags = ['V1'] router.tags = ["V1"]
router.prefix = '/api/v1' router.prefix = "/api/v1"
# 将认证依赖项应用于所有路由 # 将认证依赖项应用于所有路由
if dependencies: if dependencies:
router.dependencies = dependencies router.dependencies = dependencies

View File

@@ -1,6 +1,11 @@
from fastapi import Request from fastapi import Request
from app.controllers.v1.base import new_router from app.controllers.v1.base import new_router
from app.models.schema import VideoScriptResponse, VideoScriptRequest, VideoTermsResponse, VideoTermsRequest from app.models.schema import (
VideoScriptResponse,
VideoScriptRequest,
VideoTermsResponse,
VideoTermsRequest,
)
from app.services import llm from app.services import llm
from app.utils import utils from app.utils import utils
@@ -9,23 +14,31 @@ from app.utils import utils
router = new_router() router = new_router()
@router.post("/scripts", response_model=VideoScriptResponse, summary="Create a script for the video") @router.post(
"/scripts",
response_model=VideoScriptResponse,
summary="Create a script for the video",
)
def generate_video_script(request: Request, body: VideoScriptRequest): def generate_video_script(request: Request, body: VideoScriptRequest):
video_script = llm.generate_script(video_subject=body.video_subject, video_script = llm.generate_script(
language=body.video_language, video_subject=body.video_subject,
paragraph_number=body.paragraph_number) language=body.video_language,
response = { paragraph_number=body.paragraph_number,
"video_script": video_script )
} response = {"video_script": video_script}
return utils.get_response(200, response) return utils.get_response(200, response)
@router.post("/terms", response_model=VideoTermsResponse, summary="Generate video terms based on the video script") @router.post(
"/terms",
response_model=VideoTermsResponse,
summary="Generate video terms based on the video script",
)
def generate_video_terms(request: Request, body: VideoTermsRequest): def generate_video_terms(request: Request, body: VideoTermsRequest):
video_terms = llm.generate_terms(video_subject=body.video_subject, video_terms = llm.generate_terms(
video_script=body.video_script, video_subject=body.video_subject,
amount=body.amount) video_script=body.video_script,
response = { amount=body.amount,
"video_terms": video_terms )
} response = {"video_terms": video_terms}
return utils.get_response(200, response) return utils.get_response(200, response)

View File

@@ -1,11 +1,25 @@
PUNCTUATIONS = [ PUNCTUATIONS = [
"?", ",", ".", "", ";", ":", "!", "", "?",
"", "", "", "", "", "", "", "...", ",",
".",
"",
";",
":",
"!",
"",
"",
"",
"",
"",
"",
"",
"",
"...",
] ]
TASK_STATE_FAILED = -1 TASK_STATE_FAILED = -1
TASK_STATE_COMPLETE = 1 TASK_STATE_COMPLETE = 1
TASK_STATE_PROCESSING = 4 TASK_STATE_PROCESSING = 4
FILE_TYPE_VIDEOS = ['mp4', 'mov', 'mkv', 'webm'] FILE_TYPE_VIDEOS = ["mp4", "mov", "mkv", "webm"]
FILE_TYPE_IMAGES = ['jpg', 'jpeg', 'png', 'bmp'] FILE_TYPE_IMAGES = ["jpg", "jpeg", "png", "bmp"]

View File

@@ -5,16 +5,18 @@ from loguru import logger
class HttpException(Exception): class HttpException(Exception):
def __init__(self, task_id: str, status_code: int, message: str = '', data: Any = None): def __init__(
self, task_id: str, status_code: int, message: str = "", data: Any = None
):
self.message = message self.message = message
self.status_code = status_code self.status_code = status_code
self.data = data self.data = data
# 获取异常堆栈信息 # 获取异常堆栈信息
tb_str = traceback.format_exc().strip() tb_str = traceback.format_exc().strip()
if not tb_str or tb_str == "NoneType: None": if not tb_str or tb_str == "NoneType: None":
msg = f'HttpException: {status_code}, {task_id}, {message}' msg = f"HttpException: {status_code}, {task_id}, {message}"
else: else:
msg = f'HttpException: {status_code}, {task_id}, {message}\n{tb_str}' msg = f"HttpException: {status_code}, {task_id}, {message}\n{tb_str}"
if status_code == 400: if status_code == 400:
logger.warning(msg) logger.warning(msg)

View File

@@ -21,6 +21,7 @@ def _generate_response(prompt: str) -> str:
if not model_name: if not model_name:
model_name = "gpt-3.5-turbo-16k-0613" model_name = "gpt-3.5-turbo-16k-0613"
import g4f import g4f
content = g4f.ChatCompletion.create( content = g4f.ChatCompletion.create(
model=model_name, model=model_name,
messages=[{"role": "user", "content": prompt}], messages=[{"role": "user", "content": prompt}],
@@ -78,44 +79,56 @@ def _generate_response(prompt: str) -> str:
base_url = config.app.get("ernie_base_url") base_url = config.app.get("ernie_base_url")
model_name = "***" model_name = "***"
if not secret_key: if not secret_key:
raise ValueError(f"{llm_provider}: secret_key is not set, please set it in the config.toml file.") raise ValueError(
f"{llm_provider}: secret_key is not set, please set it in the config.toml file."
)
else: else:
raise ValueError("llm_provider is not set, please set it in the config.toml file.") raise ValueError(
"llm_provider is not set, please set it in the config.toml file."
)
if not api_key: if not api_key:
raise ValueError(f"{llm_provider}: api_key is not set, please set it in the config.toml file.") raise ValueError(
f"{llm_provider}: api_key is not set, please set it in the config.toml file."
)
if not model_name: if not model_name:
raise ValueError(f"{llm_provider}: model_name is not set, please set it in the config.toml file.") raise ValueError(
f"{llm_provider}: model_name is not set, please set it in the config.toml file."
)
if not base_url: if not base_url:
raise ValueError(f"{llm_provider}: base_url is not set, please set it in the config.toml file.") raise ValueError(
f"{llm_provider}: base_url is not set, please set it in the config.toml file."
)
if llm_provider == "qwen": if llm_provider == "qwen":
import dashscope import dashscope
from dashscope.api_entities.dashscope_response import GenerationResponse from dashscope.api_entities.dashscope_response import GenerationResponse
dashscope.api_key = api_key dashscope.api_key = api_key
response = dashscope.Generation.call( response = dashscope.Generation.call(
model=model_name, model=model_name, messages=[{"role": "user", "content": prompt}]
messages=[{"role": "user", "content": prompt}]
) )
if response: if response:
if isinstance(response, GenerationResponse): if isinstance(response, GenerationResponse):
status_code = response.status_code status_code = response.status_code
if status_code != 200: if status_code != 200:
raise Exception( raise Exception(
f"[{llm_provider}] returned an error response: \"{response}\"") f'[{llm_provider}] returned an error response: "{response}"'
)
content = response["output"]["text"] content = response["output"]["text"]
return content.replace("\n", "") return content.replace("\n", "")
else: else:
raise Exception( raise Exception(
f"[{llm_provider}] returned an invalid response: \"{response}\"") f'[{llm_provider}] returned an invalid response: "{response}"'
)
else: else:
raise Exception( raise Exception(f"[{llm_provider}] returned an empty response")
f"[{llm_provider}] returned an empty response")
if llm_provider == "gemini": if llm_provider == "gemini":
import google.generativeai as genai import google.generativeai as genai
genai.configure(api_key=api_key, transport='rest')
genai.configure(api_key=api_key, transport="rest")
generation_config = { generation_config = {
"temperature": 0.5, "temperature": 0.5,
@@ -127,25 +140,27 @@ def _generate_response(prompt: str) -> str:
safety_settings = [ safety_settings = [
{ {
"category": "HARM_CATEGORY_HARASSMENT", "category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_ONLY_HIGH" "threshold": "BLOCK_ONLY_HIGH",
}, },
{ {
"category": "HARM_CATEGORY_HATE_SPEECH", "category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_ONLY_HIGH" "threshold": "BLOCK_ONLY_HIGH",
}, },
{ {
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_ONLY_HIGH" "threshold": "BLOCK_ONLY_HIGH",
}, },
{ {
"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_ONLY_HIGH" "threshold": "BLOCK_ONLY_HIGH",
}, },
] ]
model = genai.GenerativeModel(model_name=model_name, model = genai.GenerativeModel(
generation_config=generation_config, model_name=model_name,
safety_settings=safety_settings) generation_config=generation_config,
safety_settings=safety_settings,
)
try: try:
response = model.generate_content(prompt) response = model.generate_content(prompt)
@@ -158,15 +173,16 @@ def _generate_response(prompt: str) -> str:
if llm_provider == "cloudflare": if llm_provider == "cloudflare":
import requests import requests
response = requests.post( response = requests.post(
f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}", f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}",
headers={"Authorization": f"Bearer {api_key}"}, headers={"Authorization": f"Bearer {api_key}"},
json={ json={
"messages": [ "messages": [
{"role": "system", "content": "You are a friendly assistant"}, {"role": "system", "content": "You are a friendly assistant"},
{"role": "user", "content": prompt} {"role": "user", "content": prompt},
] ]
} },
) )
result = response.json() result = response.json()
logger.info(result) logger.info(result)
@@ -174,30 +190,35 @@ def _generate_response(prompt: str) -> str:
if llm_provider == "ernie": if llm_provider == "ernie":
import requests import requests
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
access_token = requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params).json().get( params = {
"access_token") "grant_type": "client_credentials",
"client_id": api_key,
"client_secret": secret_key,
}
access_token = (
requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params)
.json()
.get("access_token")
)
url = f"{base_url}?access_token={access_token}" url = f"{base_url}?access_token={access_token}"
payload = json.dumps({ payload = json.dumps(
"messages": [ {
{ "messages": [{"role": "user", "content": prompt}],
"role": "user", "temperature": 0.5,
"content": prompt "top_p": 0.8,
} "penalty_score": 1,
], "disable_search": False,
"temperature": 0.5, "enable_citation": False,
"top_p": 0.8, "response_format": "text",
"penalty_score": 1, }
"disable_search": False, )
"enable_citation": False, headers = {"Content-Type": "application/json"}
"response_format": "text"
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload).json() response = requests.request(
"POST", url, headers=headers, data=payload
).json()
return response.get("result") return response.get("result")
if llm_provider == "azure": if llm_provider == "azure":
@@ -213,24 +234,27 @@ def _generate_response(prompt: str) -> str:
) )
response = client.chat.completions.create( response = client.chat.completions.create(
model=model_name, model=model_name, messages=[{"role": "user", "content": prompt}]
messages=[{"role": "user", "content": prompt}]
) )
if response: if response:
if isinstance(response, ChatCompletion): if isinstance(response, ChatCompletion):
content = response.choices[0].message.content content = response.choices[0].message.content
else: else:
raise Exception( raise Exception(
f"[{llm_provider}] returned an invalid response: \"{response}\", please check your network " f'[{llm_provider}] returned an invalid response: "{response}", please check your network '
f"connection and try again.") f"connection and try again."
)
else: else:
raise Exception( raise Exception(
f"[{llm_provider}] returned an empty response, please check your network connection and try again.") f"[{llm_provider}] returned an empty response, please check your network connection and try again."
)
return content.replace("\n", "") return content.replace("\n", "")
def generate_script(video_subject: str, language: str = "", paragraph_number: int = 1) -> str: def generate_script(
video_subject: str, language: str = "", paragraph_number: int = 1
) -> str:
prompt = f""" prompt = f"""
# Role: Video Script Generator # Role: Video Script Generator
@@ -335,14 +359,16 @@ Please note that you must use English for generating video search terms; Chinese
try: try:
response = _generate_response(prompt) response = _generate_response(prompt)
search_terms = json.loads(response) search_terms = json.loads(response)
if not isinstance(search_terms, list) or not all(isinstance(term, str) for term in search_terms): if not isinstance(search_terms, list) or not all(
isinstance(term, str) for term in search_terms
):
logger.error("response is not a list of strings.") logger.error("response is not a list of strings.")
continue continue
except Exception as e: except Exception as e:
logger.warning(f"failed to generate video terms: {str(e)}") logger.warning(f"failed to generate video terms: {str(e)}")
if response: if response:
match = re.search(r'\[.*]', response) match = re.search(r"\[.*]", response)
if match: if match:
try: try:
search_terms = json.loads(match.group()) search_terms = json.loads(match.group())
@@ -361,9 +387,13 @@ Please note that you must use English for generating video search terms; Chinese
if __name__ == "__main__": if __name__ == "__main__":
video_subject = "生命的意义是什么" video_subject = "生命的意义是什么"
script = generate_script(video_subject=video_subject, language="zh-CN", paragraph_number=1) script = generate_script(
video_subject=video_subject, language="zh-CN", paragraph_number=1
)
print("######################") print("######################")
print(script) print(script)
search_terms = generate_terms(video_subject=video_subject, video_script=script, amount=5) search_terms = generate_terms(
video_subject=video_subject, video_script=script, amount=5
)
print("######################") print("######################")
print(search_terms) print(search_terms)

View File

@@ -19,7 +19,8 @@ def get_api_key(cfg_key: str):
if not api_keys: if not api_keys:
raise ValueError( raise ValueError(
f"\n\n##### {cfg_key} is not set #####\n\nPlease set it in the config.toml file: {config.config_file}\n\n" f"\n\n##### {cfg_key} is not set #####\n\nPlease set it in the config.toml file: {config.config_file}\n\n"
f"{utils.to_json(config.app)}") f"{utils.to_json(config.app)}"
)
# if only one key is provided, return it # if only one key is provided, return it
if isinstance(api_keys, str): if isinstance(api_keys, str):
@@ -30,28 +31,29 @@ def get_api_key(cfg_key: str):
return api_keys[requested_count % len(api_keys)] return api_keys[requested_count % len(api_keys)]
def search_videos_pexels(search_term: str, def search_videos_pexels(
minimum_duration: int, search_term: str,
video_aspect: VideoAspect = VideoAspect.portrait, minimum_duration: int,
) -> List[MaterialInfo]: video_aspect: VideoAspect = VideoAspect.portrait,
) -> List[MaterialInfo]:
aspect = VideoAspect(video_aspect) aspect = VideoAspect(video_aspect)
video_orientation = aspect.name video_orientation = aspect.name
video_width, video_height = aspect.to_resolution() video_width, video_height = aspect.to_resolution()
api_key = get_api_key("pexels_api_keys") api_key = get_api_key("pexels_api_keys")
headers = { headers = {"Authorization": api_key}
"Authorization": api_key
}
# Build URL # Build URL
params = { params = {"query": search_term, "per_page": 20, "orientation": video_orientation}
"query": search_term,
"per_page": 20,
"orientation": video_orientation
}
query_url = f"https://api.pexels.com/videos/search?{urlencode(params)}" query_url = f"https://api.pexels.com/videos/search?{urlencode(params)}"
logger.info(f"searching videos: {query_url}, with proxies: {config.proxy}") logger.info(f"searching videos: {query_url}, with proxies: {config.proxy}")
try: try:
r = requests.get(query_url, headers=headers, proxies=config.proxy, verify=False, timeout=(30, 60)) r = requests.get(
query_url,
headers=headers,
proxies=config.proxy,
verify=False,
timeout=(30, 60),
)
response = r.json() response = r.json()
video_items = [] video_items = []
if "videos" not in response: if "videos" not in response:
@@ -83,10 +85,11 @@ def search_videos_pexels(search_term: str,
return [] return []
def search_videos_pixabay(search_term: str, def search_videos_pixabay(
minimum_duration: int, search_term: str,
video_aspect: VideoAspect = VideoAspect.portrait, minimum_duration: int,
) -> List[MaterialInfo]: video_aspect: VideoAspect = VideoAspect.portrait,
) -> List[MaterialInfo]:
aspect = VideoAspect(video_aspect) aspect = VideoAspect(video_aspect)
video_width, video_height = aspect.to_resolution() video_width, video_height = aspect.to_resolution()
@@ -97,13 +100,15 @@ def search_videos_pixabay(search_term: str,
"q": search_term, "q": search_term,
"video_type": "all", # Accepted values: "all", "film", "animation" "video_type": "all", # Accepted values: "all", "film", "animation"
"per_page": 50, "per_page": 50,
"key": api_key "key": api_key,
} }
query_url = f"https://pixabay.com/api/videos/?{urlencode(params)}" query_url = f"https://pixabay.com/api/videos/?{urlencode(params)}"
logger.info(f"searching videos: {query_url}, with proxies: {config.proxy}") logger.info(f"searching videos: {query_url}, with proxies: {config.proxy}")
try: try:
r = requests.get(query_url, proxies=config.proxy, verify=False, timeout=(30, 60)) r = requests.get(
query_url, proxies=config.proxy, verify=False, timeout=(30, 60)
)
response = r.json() response = r.json()
video_items = [] video_items = []
if "hits" not in response: if "hits" not in response:
@@ -155,7 +160,11 @@ def save_video(video_url: str, save_dir: str = "") -> str:
# if video does not exist, download it # if video does not exist, download it
with open(video_path, "wb") as f: with open(video_path, "wb") as f:
f.write(requests.get(video_url, proxies=config.proxy, verify=False, timeout=(60, 240)).content) f.write(
requests.get(
video_url, proxies=config.proxy, verify=False, timeout=(60, 240)
).content
)
if os.path.exists(video_path) and os.path.getsize(video_path) > 0: if os.path.exists(video_path) and os.path.getsize(video_path) > 0:
try: try:
@@ -174,14 +183,15 @@ def save_video(video_url: str, save_dir: str = "") -> str:
return "" return ""
def download_videos(task_id: str, def download_videos(
search_terms: List[str], task_id: str,
source: str = "pexels", search_terms: List[str],
video_aspect: VideoAspect = VideoAspect.portrait, source: str = "pexels",
video_contact_mode: VideoConcatMode = VideoConcatMode.random, video_aspect: VideoAspect = VideoAspect.portrait,
audio_duration: float = 0.0, video_contact_mode: VideoConcatMode = VideoConcatMode.random,
max_clip_duration: int = 5, audio_duration: float = 0.0,
) -> List[str]: max_clip_duration: int = 5,
) -> List[str]:
valid_video_items = [] valid_video_items = []
valid_video_urls = [] valid_video_urls = []
found_duration = 0.0 found_duration = 0.0
@@ -190,9 +200,11 @@ def download_videos(task_id: str,
search_videos = search_videos_pixabay search_videos = search_videos_pixabay
for search_term in search_terms: for search_term in search_terms:
video_items = search_videos(search_term=search_term, video_items = search_videos(
minimum_duration=max_clip_duration, search_term=search_term,
video_aspect=video_aspect) minimum_duration=max_clip_duration,
video_aspect=video_aspect,
)
logger.info(f"found {len(video_items)} videos for '{search_term}'") logger.info(f"found {len(video_items)} videos for '{search_term}'")
for item in video_items: for item in video_items:
@@ -202,7 +214,8 @@ def download_videos(task_id: str,
found_duration += item.duration found_duration += item.duration
logger.info( logger.info(
f"found total videos: {len(valid_video_items)}, required duration: {audio_duration} seconds, found duration: {found_duration} seconds") f"found total videos: {len(valid_video_items)}, required duration: {audio_duration} seconds, found duration: {found_duration} seconds"
)
video_paths = [] video_paths = []
material_directory = config.app.get("material_directory", "").strip() material_directory = config.app.get("material_directory", "").strip()
@@ -218,14 +231,18 @@ def download_videos(task_id: str,
for item in valid_video_items: for item in valid_video_items:
try: try:
logger.info(f"downloading video: {item.url}") logger.info(f"downloading video: {item.url}")
saved_video_path = save_video(video_url=item.url, save_dir=material_directory) saved_video_path = save_video(
video_url=item.url, save_dir=material_directory
)
if saved_video_path: if saved_video_path:
logger.info(f"video saved: {saved_video_path}") logger.info(f"video saved: {saved_video_path}")
video_paths.append(saved_video_path) video_paths.append(saved_video_path)
seconds = min(max_clip_duration, item.duration) seconds = min(max_clip_duration, item.duration)
total_duration += seconds total_duration += seconds
if total_duration > audio_duration: if total_duration > audio_duration:
logger.info(f"total duration of downloaded videos: {total_duration} seconds, skip downloading more") logger.info(
f"total duration of downloaded videos: {total_duration} seconds, skip downloading more"
)
break break
except Exception as e: except Exception as e:
logger.error(f"failed to download video: {utils.to_json(item)} => {str(e)}") logger.error(f"failed to download video: {utils.to_json(item)} => {str(e)}")
@@ -234,4 +251,6 @@ def download_videos(task_id: str,
if __name__ == "__main__": if __name__ == "__main__":
download_videos("test123", ["Money Exchange Medium"], audio_duration=100, source="pixabay") download_videos(
"test123", ["Money Exchange Medium"], audio_duration=100, source="pixabay"
)

View File

@@ -6,7 +6,6 @@ from app.models import const
# Base class for state management # Base class for state management
class BaseState(ABC): class BaseState(ABC):
@abstractmethod @abstractmethod
def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs): def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
pass pass
@@ -18,11 +17,16 @@ class BaseState(ABC):
# Memory state management # Memory state management
class MemoryState(BaseState): class MemoryState(BaseState):
def __init__(self): def __init__(self):
self._tasks = {} self._tasks = {}
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs): def update_task(
self,
task_id: str,
state: int = const.TASK_STATE_PROCESSING,
progress: int = 0,
**kwargs,
):
progress = int(progress) progress = int(progress)
if progress > 100: if progress > 100:
progress = 100 progress = 100
@@ -43,12 +47,18 @@ class MemoryState(BaseState):
# Redis state management # Redis state management
class RedisState(BaseState): class RedisState(BaseState):
def __init__(self, host="localhost", port=6379, db=0, password=None):
def __init__(self, host='localhost', port=6379, db=0, password=None):
import redis import redis
self._redis = redis.StrictRedis(host=host, port=port, db=db, password=password) self._redis = redis.StrictRedis(host=host, port=port, db=db, password=password)
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs): def update_task(
self,
task_id: str,
state: int = const.TASK_STATE_PROCESSING,
progress: int = 0,
**kwargs,
):
progress = int(progress) progress = int(progress)
if progress > 100: if progress > 100:
progress = 100 progress = 100
@@ -67,7 +77,10 @@ class RedisState(BaseState):
if not task_data: if not task_data:
return None return None
task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()} task = {
key.decode("utf-8"): self._convert_to_original_type(value)
for key, value in task_data.items()
}
return task return task
def delete_task(self, task_id: str): def delete_task(self, task_id: str):
@@ -79,7 +92,7 @@ class RedisState(BaseState):
Convert the value from byte string to its original data type. Convert the value from byte string to its original data type.
You can extend this method to handle other data types as needed. You can extend this method to handle other data types as needed.
""" """
value_str = value.decode('utf-8') value_str = value.decode("utf-8")
try: try:
# try to convert byte string array to list # try to convert byte string array to list
@@ -100,4 +113,10 @@ _redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0) _redis_db = config.app.get("redis_db", 0)
_redis_password = config.app.get("redis_password", None) _redis_password = config.app.get("redis_password", None)
state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db, password=_redis_password) if _enable_redis else MemoryState() state = (
RedisState(
host=_redis_host, port=_redis_port, db=_redis_db, password=_redis_password
)
if _enable_redis
else MemoryState()
)

View File

@@ -23,18 +23,22 @@ def create(audio_file, subtitle_file: str = ""):
if not os.path.isdir(model_path) or not os.path.isfile(model_bin_file): if not os.path.isdir(model_path) or not os.path.isfile(model_bin_file):
model_path = model_size model_path = model_size
logger.info(f"loading model: {model_path}, device: {device}, compute_type: {compute_type}") logger.info(
f"loading model: {model_path}, device: {device}, compute_type: {compute_type}"
)
try: try:
model = WhisperModel(model_size_or_path=model_path, model = WhisperModel(
device=device, model_size_or_path=model_path, device=device, compute_type=compute_type
compute_type=compute_type) )
except Exception as e: except Exception as e:
logger.error(f"failed to load model: {e} \n\n" logger.error(
f"********************************************\n" f"failed to load model: {e} \n\n"
f"this may be caused by network issue. \n" f"********************************************\n"
f"please download the model manually and put it in the 'models' folder. \n" f"this may be caused by network issue. \n"
f"see [README.md FAQ](https://github.com/harry0703/MoneyPrinterTurbo) for more details.\n" f"please download the model manually and put it in the 'models' folder. \n"
f"********************************************\n\n") f"see [README.md FAQ](https://github.com/harry0703/MoneyPrinterTurbo) for more details.\n"
f"********************************************\n\n"
)
return None return None
logger.info(f"start, output file: {subtitle_file}") logger.info(f"start, output file: {subtitle_file}")
@@ -49,7 +53,9 @@ def create(audio_file, subtitle_file: str = ""):
vad_parameters=dict(min_silence_duration_ms=500), vad_parameters=dict(min_silence_duration_ms=500),
) )
logger.info(f"detected language: '{info.language}', probability: {info.language_probability:.2f}") logger.info(
f"detected language: '{info.language}', probability: {info.language_probability:.2f}"
)
start = timer() start = timer()
subtitles = [] subtitles = []
@@ -62,11 +68,9 @@ def create(audio_file, subtitle_file: str = ""):
msg = "[%.2fs -> %.2fs] %s" % (seg_start, seg_end, seg_text) msg = "[%.2fs -> %.2fs] %s" % (seg_start, seg_end, seg_text)
logger.debug(msg) logger.debug(msg)
subtitles.append({ subtitles.append(
"msg": seg_text, {"msg": seg_text, "start_time": seg_start, "end_time": seg_end}
"start_time": seg_start, )
"end_time": seg_end
})
for segment in segments: for segment in segments:
words_idx = 0 words_idx = 0
@@ -119,7 +123,11 @@ def create(audio_file, subtitle_file: str = ""):
for subtitle in subtitles: for subtitle in subtitles:
text = subtitle.get("msg") text = subtitle.get("msg")
if text: if text:
lines.append(utils.text_to_srt(idx, text, subtitle.get("start_time"), subtitle.get("end_time"))) lines.append(
utils.text_to_srt(
idx, text, subtitle.get("start_time"), subtitle.get("end_time")
)
)
idx += 1 idx += 1
sub = "\n".join(lines) + "\n" sub = "\n".join(lines) + "\n"
@@ -136,12 +144,12 @@ def file_to_subtitles(filename):
current_times = None current_times = None
current_text = "" current_text = ""
index = 0 index = 0
with open(filename, 'r', encoding="utf-8") as f: with open(filename, "r", encoding="utf-8") as f:
for line in f: for line in f:
times = re.findall("([0-9]*:[0-9]*:[0-9]*,[0-9]*)", line) times = re.findall("([0-9]*:[0-9]*:[0-9]*,[0-9]*)", line)
if times: if times:
current_times = line current_times = line
elif line.strip() == '' and current_times: elif line.strip() == "" and current_times:
index += 1 index += 1
times_texts.append((index, current_times.strip(), current_text.strip())) times_texts.append((index, current_times.strip(), current_text.strip()))
current_times, current_text = None, "" current_times, current_text = None, ""
@@ -169,6 +177,7 @@ def levenshtein_distance(s1, s2):
return previous_row[-1] return previous_row[-1]
def similarity(a, b): def similarity(a, b):
distance = levenshtein_distance(a.lower(), b.lower()) distance = levenshtein_distance(a.lower(), b.lower())
max_length = max(len(a), len(b)) max_length = max(len(a), len(b))
@@ -194,26 +203,44 @@ def correct(subtitle_file, video_script):
subtitle_index += 1 subtitle_index += 1
else: else:
combined_subtitle = subtitle_line combined_subtitle = subtitle_line
start_time = subtitle_items[subtitle_index][1].split(' --> ')[0] start_time = subtitle_items[subtitle_index][1].split(" --> ")[0]
end_time = subtitle_items[subtitle_index][1].split(' --> ')[1] end_time = subtitle_items[subtitle_index][1].split(" --> ")[1]
next_subtitle_index = subtitle_index + 1 next_subtitle_index = subtitle_index + 1
while next_subtitle_index < len(subtitle_items): while next_subtitle_index < len(subtitle_items):
next_subtitle = subtitle_items[next_subtitle_index][2].strip() next_subtitle = subtitle_items[next_subtitle_index][2].strip()
if similarity(script_line, combined_subtitle + " " + next_subtitle) > similarity(script_line, combined_subtitle): if similarity(
script_line, combined_subtitle + " " + next_subtitle
) > similarity(script_line, combined_subtitle):
combined_subtitle += " " + next_subtitle combined_subtitle += " " + next_subtitle
end_time = subtitle_items[next_subtitle_index][1].split(' --> ')[1] end_time = subtitle_items[next_subtitle_index][1].split(" --> ")[1]
next_subtitle_index += 1 next_subtitle_index += 1
else: else:
break break
if similarity(script_line, combined_subtitle) > 0.8: if similarity(script_line, combined_subtitle) > 0.8:
logger.warning(f"Merged/Corrected - Script: {script_line}, Subtitle: {combined_subtitle}") logger.warning(
new_subtitle_items.append((len(new_subtitle_items) + 1, f"{start_time} --> {end_time}", script_line)) f"Merged/Corrected - Script: {script_line}, Subtitle: {combined_subtitle}"
)
new_subtitle_items.append(
(
len(new_subtitle_items) + 1,
f"{start_time} --> {end_time}",
script_line,
)
)
corrected = True corrected = True
else: else:
logger.warning(f"Mismatch - Script: {script_line}, Subtitle: {combined_subtitle}") logger.warning(
new_subtitle_items.append((len(new_subtitle_items) + 1, f"{start_time} --> {end_time}", script_line)) f"Mismatch - Script: {script_line}, Subtitle: {combined_subtitle}"
)
new_subtitle_items.append(
(
len(new_subtitle_items) + 1,
f"{start_time} --> {end_time}",
script_line,
)
)
corrected = True corrected = True
script_index += 1 script_index += 1
@@ -223,10 +250,22 @@ def correct(subtitle_file, video_script):
while script_index < len(script_lines): while script_index < len(script_lines):
logger.warning(f"Extra script line: {script_lines[script_index]}") logger.warning(f"Extra script line: {script_lines[script_index]}")
if subtitle_index < len(subtitle_items): if subtitle_index < len(subtitle_items):
new_subtitle_items.append((len(new_subtitle_items) + 1, subtitle_items[subtitle_index][1], script_lines[script_index])) new_subtitle_items.append(
(
len(new_subtitle_items) + 1,
subtitle_items[subtitle_index][1],
script_lines[script_index],
)
)
subtitle_index += 1 subtitle_index += 1
else: else:
new_subtitle_items.append((len(new_subtitle_items) + 1, "00:00:00,000 --> 00:00:00,000", script_lines[script_index])) new_subtitle_items.append(
(
len(new_subtitle_items) + 1,
"00:00:00,000 --> 00:00:00,000",
script_lines[script_index],
)
)
script_index += 1 script_index += 1
corrected = True corrected = True

View File

@@ -988,7 +988,7 @@ Name: zh-CN-XiaoxiaoMultilingualNeural-V2
Gender: Female Gender: Female
""".strip() """.strip()
voices = [] voices = []
name = '' name = ""
for line in voices_str.split("\n"): for line in voices_str.split("\n"):
line = line.strip() line = line.strip()
if not line: if not line:
@@ -1008,7 +1008,7 @@ Gender: Female
voices.append(f"{name}-{gender}") voices.append(f"{name}-{gender}")
else: else:
voices.append(f"{name}-{gender}") voices.append(f"{name}-{gender}")
name = '' name = ""
voices.sort() voices.sort()
return voices return voices
@@ -1028,7 +1028,9 @@ def is_azure_v2_voice(voice_name: str):
return "" return ""
def tts(text: str, voice_name: str, voice_rate: float, voice_file: str) -> [SubMaker, None]: def tts(
text: str, voice_name: str, voice_rate: float, voice_file: str
) -> [SubMaker, None]:
if is_azure_v2_voice(voice_name): if is_azure_v2_voice(voice_name):
return azure_tts_v2(text, voice_name, voice_file) return azure_tts_v2(text, voice_name, voice_file)
return azure_tts_v1(text, voice_name, voice_rate, voice_file) return azure_tts_v1(text, voice_name, voice_rate, voice_file)
@@ -1044,7 +1046,9 @@ def convert_rate_to_percent(rate: float) -> str:
return f"{percent}%" return f"{percent}%"
def azure_tts_v1(text: str, voice_name: str, voice_rate: float, voice_file: str) -> [SubMaker, None]: def azure_tts_v1(
text: str, voice_name: str, voice_rate: float, voice_file: str
) -> [SubMaker, None]:
voice_name = parse_voice_name(voice_name) voice_name = parse_voice_name(voice_name)
text = text.strip() text = text.strip()
rate_str = convert_rate_to_percent(voice_rate) rate_str = convert_rate_to_percent(voice_rate)
@@ -1060,7 +1064,9 @@ def azure_tts_v1(text: str, voice_name: str, voice_rate: float, voice_file: str)
if chunk["type"] == "audio": if chunk["type"] == "audio":
file.write(chunk["data"]) file.write(chunk["data"])
elif chunk["type"] == "WordBoundary": elif chunk["type"] == "WordBoundary":
sub_maker.create_sub((chunk["offset"], chunk["duration"]), chunk["text"]) sub_maker.create_sub(
(chunk["offset"], chunk["duration"]), chunk["text"]
)
return sub_maker return sub_maker
sub_maker = asyncio.run(_do()) sub_maker = asyncio.run(_do())
@@ -1085,8 +1091,12 @@ def azure_tts_v2(text: str, voice_name: str, voice_file: str) -> [SubMaker, None
def _format_duration_to_offset(duration) -> int: def _format_duration_to_offset(duration) -> int:
if isinstance(duration, str): if isinstance(duration, str):
time_obj = datetime.strptime(duration, "%H:%M:%S.%f") time_obj = datetime.strptime(duration, "%H:%M:%S.%f")
milliseconds = (time_obj.hour * 3600000) + (time_obj.minute * 60000) + (time_obj.second * 1000) + ( milliseconds = (
time_obj.microsecond // 1000) (time_obj.hour * 3600000)
+ (time_obj.minute * 60000)
+ (time_obj.second * 1000)
+ (time_obj.microsecond // 1000)
)
return milliseconds * 10000 return milliseconds * 10000
if isinstance(duration, int): if isinstance(duration, int):
@@ -1119,20 +1129,29 @@ def azure_tts_v2(text: str, voice_name: str, voice_file: str) -> [SubMaker, None
# Creates an instance of a speech config with specified subscription key and service region. # Creates an instance of a speech config with specified subscription key and service region.
speech_key = config.azure.get("speech_key", "") speech_key = config.azure.get("speech_key", "")
service_region = config.azure.get("speech_region", "") service_region = config.azure.get("speech_region", "")
audio_config = speechsdk.audio.AudioOutputConfig(filename=voice_file, use_default_speaker=True) audio_config = speechsdk.audio.AudioOutputConfig(
speech_config = speechsdk.SpeechConfig(subscription=speech_key, filename=voice_file, use_default_speaker=True
region=service_region) )
speech_config = speechsdk.SpeechConfig(
subscription=speech_key, region=service_region
)
speech_config.speech_synthesis_voice_name = voice_name speech_config.speech_synthesis_voice_name = voice_name
# speech_config.set_property(property_id=speechsdk.PropertyId.SpeechServiceResponse_RequestSentenceBoundary, # speech_config.set_property(property_id=speechsdk.PropertyId.SpeechServiceResponse_RequestSentenceBoundary,
# value='true') # value='true')
speech_config.set_property(property_id=speechsdk.PropertyId.SpeechServiceResponse_RequestWordBoundary, speech_config.set_property(
value='true') property_id=speechsdk.PropertyId.SpeechServiceResponse_RequestWordBoundary,
value="true",
)
speech_config.set_speech_synthesis_output_format( speech_config.set_speech_synthesis_output_format(
speechsdk.SpeechSynthesisOutputFormat.Audio48Khz192KBitRateMonoMp3) speechsdk.SpeechSynthesisOutputFormat.Audio48Khz192KBitRateMonoMp3
speech_synthesizer = speechsdk.SpeechSynthesizer(audio_config=audio_config, )
speech_config=speech_config) speech_synthesizer = speechsdk.SpeechSynthesizer(
speech_synthesizer.synthesis_word_boundary.connect(speech_synthesizer_word_boundary_cb) audio_config=audio_config, speech_config=speech_config
)
speech_synthesizer.synthesis_word_boundary.connect(
speech_synthesizer_word_boundary_cb
)
result = speech_synthesizer.speak_text_async(text).get() result = speech_synthesizer.speak_text_async(text).get()
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted: if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
@@ -1140,9 +1159,13 @@ def azure_tts_v2(text: str, voice_name: str, voice_file: str) -> [SubMaker, None
return sub_maker return sub_maker
elif result.reason == speechsdk.ResultReason.Canceled: elif result.reason == speechsdk.ResultReason.Canceled:
cancellation_details = result.cancellation_details cancellation_details = result.cancellation_details
logger.error(f"azure v2 speech synthesis canceled: {cancellation_details.reason}") logger.error(
f"azure v2 speech synthesis canceled: {cancellation_details.reason}"
)
if cancellation_details.reason == speechsdk.CancellationReason.Error: if cancellation_details.reason == speechsdk.CancellationReason.Error:
logger.error(f"azure v2 speech synthesis error: {cancellation_details.error_details}") logger.error(
f"azure v2 speech synthesis error: {cancellation_details.error_details}"
)
logger.info(f"completed, output file: {voice_file}") logger.info(f"completed, output file: {voice_file}")
except Exception as e: except Exception as e:
logger.error(f"failed, error: {str(e)}") logger.error(f"failed, error: {str(e)}")
@@ -1179,11 +1202,7 @@ def create_subtitle(sub_maker: submaker.SubMaker, text: str, subtitle_file: str)
""" """
start_t = mktimestamp(start_time).replace(".", ",") start_t = mktimestamp(start_time).replace(".", ",")
end_t = mktimestamp(end_time).replace(".", ",") end_t = mktimestamp(end_time).replace(".", ",")
return ( return f"{idx}\n" f"{start_t} --> {end_t}\n" f"{sub_text}\n"
f"{idx}\n"
f"{start_t} --> {end_t}\n"
f"{sub_text}\n"
)
start_time = -1.0 start_time = -1.0
sub_items = [] sub_items = []
@@ -1240,12 +1259,16 @@ def create_subtitle(sub_maker: submaker.SubMaker, text: str, subtitle_file: str)
try: try:
sbs = subtitles.file_to_subtitles(subtitle_file, encoding="utf-8") sbs = subtitles.file_to_subtitles(subtitle_file, encoding="utf-8")
duration = max([tb for ((ta, tb), txt) in sbs]) duration = max([tb for ((ta, tb), txt) in sbs])
logger.info(f"completed, subtitle file created: {subtitle_file}, duration: {duration}") logger.info(
f"completed, subtitle file created: {subtitle_file}, duration: {duration}"
)
except Exception as e: except Exception as e:
logger.error(f"failed, error: {str(e)}") logger.error(f"failed, error: {str(e)}")
os.remove(subtitle_file) os.remove(subtitle_file)
else: else:
logger.warning(f"failed, sub_items len: {len(sub_items)}, script_lines len: {len(script_lines)}") logger.warning(
f"failed, sub_items len: {len(sub_items)}, script_lines len: {len(script_lines)}"
)
except Exception as e: except Exception as e:
logger.error(f"failed, error: {str(e)}") logger.error(f"failed, error: {str(e)}")
@@ -1269,7 +1292,6 @@ if __name__ == "__main__":
voices = get_all_azure_voices() voices = get_all_azure_voices()
print(len(voices)) print(len(voices))
async def _do(): async def _do():
temp_dir = utils.storage_dir("temp") temp_dir = utils.storage_dir("temp")
@@ -1318,12 +1340,13 @@ if __name__ == "__main__":
for voice_name in voice_names: for voice_name in voice_names:
voice_file = f"{temp_dir}/tts-{voice_name}.mp3" voice_file = f"{temp_dir}/tts-{voice_name}.mp3"
subtitle_file = f"{temp_dir}/tts.mp3.srt" subtitle_file = f"{temp_dir}/tts.mp3.srt"
sub_maker = azure_tts_v2(text=text, voice_name=voice_name, voice_file=voice_file) sub_maker = azure_tts_v2(
text=text, voice_name=voice_name, voice_file=voice_file
)
create_subtitle(sub_maker=sub_maker, text=text, subtitle_file=subtitle_file) create_subtitle(sub_maker=sub_maker, text=text, subtitle_file=subtitle_file)
audio_duration = get_audio_duration(sub_maker) audio_duration = get_audio_duration(sub_maker)
print(f"voice: {voice_name}, audio duration: {audio_duration}s") print(f"voice: {voice_name}, audio duration: {audio_duration}s")
loop = asyncio.get_event_loop_policy().get_event_loop() loop = asyncio.get_event_loop_policy().get_event_loop()
try: try:
loop.run_until_complete(_do()) loop.run_until_complete(_do())

View File

@@ -15,12 +15,12 @@ urllib3.disable_warnings()
def get_response(status: int, data: Any = None, message: str = ""): def get_response(status: int, data: Any = None, message: str = ""):
obj = { obj = {
'status': status, "status": status,
} }
if data: if data:
obj['data'] = data obj["data"] = data
if message: if message:
obj['message'] = message obj["message"] = message
return obj return obj
@@ -41,7 +41,7 @@ def to_json(obj):
elif isinstance(o, (list, tuple)): elif isinstance(o, (list, tuple)):
return [serialize(item) for item in o] return [serialize(item) for item in o]
# 如果对象是自定义类型尝试返回其__dict__属性 # 如果对象是自定义类型尝试返回其__dict__属性
elif hasattr(o, '__dict__'): elif hasattr(o, "__dict__"):
return serialize(o.__dict__) return serialize(o.__dict__)
# 其他情况返回None或者可以选择抛出异常 # 其他情况返回None或者可以选择抛出异常
else: else:
@@ -199,7 +199,8 @@ def split_string_by_punctuations(s):
def md5(text): def md5(text):
import hashlib import hashlib
return hashlib.md5(text.encode('utf-8')).hexdigest()
return hashlib.md5(text.encode("utf-8")).hexdigest()
def get_system_locale(): def get_system_locale():

View File

@@ -12,6 +12,6 @@ build_and_render(
parse_refs=False, parse_refs=False,
sections=["build", "deps", "feat", "fix", "refactor"], sections=["build", "deps", "feat", "fix", "refactor"],
versioning="pep440", versioning="pep440",
bump="1.1.2", # 指定bump版本 bump="1.1.2", # 指定bump版本
in_place=True, in_place=True,
) )