Format project code

This commit is contained in:
yyhhyyyyyy
2024-07-24 14:59:06 +08:00
parent bbd4e94941
commit 905841965a
18 changed files with 410 additions and 214 deletions

View File

@@ -7,14 +7,14 @@ from app.models.exception import HttpException
def get_task_id(request: Request):
task_id = request.headers.get('x-task-id')
task_id = request.headers.get("x-task-id")
if not task_id:
task_id = uuid4()
return str(task_id)
def get_api_key(request: Request):
api_key = request.headers.get('x-api-key')
api_key = request.headers.get("x-api-key")
return api_key
@@ -23,5 +23,9 @@ def verify_token(request: Request):
if token != config.app.get("api_key", ""):
request_id = get_task_id(request)
request_url = request.url
user_agent = request.headers.get('user-agent')
raise HttpException(task_id=request_id, status_code=401, message=f"invalid token: {request_url}, {user_agent}")
user_agent = request.headers.get("user-agent")
raise HttpException(
task_id=request_id,
status_code=401,
message=f"invalid token: {request_url}, {user_agent}",
)

View File

@@ -18,11 +18,15 @@ class TaskManager:
print(f"add task: {func.__name__}, current_tasks: {self.current_tasks}")
self.execute_task(func, *args, **kwargs)
else:
print(f"enqueue task: {func.__name__}, current_tasks: {self.current_tasks}")
print(
f"enqueue task: {func.__name__}, current_tasks: {self.current_tasks}"
)
self.enqueue({"func": func, "args": args, "kwargs": kwargs})
def execute_task(self, func: Callable, *args: Any, **kwargs: Any):
thread = threading.Thread(target=self.run_task, args=(func, *args), kwargs=kwargs)
thread = threading.Thread(
target=self.run_task, args=(func, *args), kwargs=kwargs
)
thread.start()
def run_task(self, func: Callable, *args: Any, **kwargs: Any):
@@ -35,11 +39,14 @@ class TaskManager:
def check_queue(self):
with self.lock:
if self.current_tasks < self.max_concurrent_tasks and not self.is_queue_empty():
if (
self.current_tasks < self.max_concurrent_tasks
and not self.is_queue_empty()
):
task_info = self.dequeue()
func = task_info['func']
args = task_info.get('args', ())
kwargs = task_info.get('kwargs', {})
func = task_info["func"]
args = task_info.get("args", ())
kwargs = task_info.get("kwargs", {})
self.execute_task(func, *args, **kwargs)
def task_done(self):

View File

@@ -8,7 +8,7 @@ from app.models.schema import VideoParams
from app.services import task as tm
FUNC_MAP = {
'start': tm.start,
"start": tm.start,
# 'start_test': tm.start_test
}
@@ -24,11 +24,15 @@ class RedisTaskManager(TaskManager):
def enqueue(self, task: Dict):
task_with_serializable_params = task.copy()
if 'params' in task['kwargs'] and isinstance(task['kwargs']['params'], VideoParams):
task_with_serializable_params['kwargs']['params'] = task['kwargs']['params'].dict()
if "params" in task["kwargs"] and isinstance(
task["kwargs"]["params"], VideoParams
):
task_with_serializable_params["kwargs"]["params"] = task["kwargs"][
"params"
].dict()
# 将函数对象转换为其名称
task_with_serializable_params['func'] = task['func'].__name__
task_with_serializable_params["func"] = task["func"].__name__
self.redis_client.rpush(self.queue, json.dumps(task_with_serializable_params))
def dequeue(self):
@@ -36,10 +40,14 @@ class RedisTaskManager(TaskManager):
if task_json:
task_info = json.loads(task_json)
# 将函数名称转换回函数对象
task_info['func'] = FUNC_MAP[task_info['func']]
task_info["func"] = FUNC_MAP[task_info["func"]]
if 'params' in task_info['kwargs'] and isinstance(task_info['kwargs']['params'], dict):
task_info['kwargs']['params'] = VideoParams(**task_info['kwargs']['params'])
if "params" in task_info["kwargs"] and isinstance(
task_info["kwargs"]["params"], dict
):
task_info["kwargs"]["params"] = VideoParams(
**task_info["kwargs"]["params"]
)
return task_info
return None

View File

@@ -4,6 +4,11 @@ from fastapi import Request
router = APIRouter()
@router.get("/ping", tags=["Health Check"], description="检查服务可用性", response_description="pong")
@router.get(
"/ping",
tags=["Health Check"],
description="检查服务可用性",
response_description="pong",
)
def ping(request: Request) -> str:
return "pong"

View File

@@ -3,8 +3,8 @@ from fastapi import APIRouter, Depends
def new_router(dependencies=None):
router = APIRouter()
router.tags = ['V1']
router.prefix = '/api/v1'
router.tags = ["V1"]
router.prefix = "/api/v1"
# 将认证依赖项应用于所有路由
if dependencies:
router.dependencies = dependencies

View File

@@ -1,6 +1,11 @@
from fastapi import Request
from app.controllers.v1.base import new_router
from app.models.schema import VideoScriptResponse, VideoScriptRequest, VideoTermsResponse, VideoTermsRequest
from app.models.schema import (
VideoScriptResponse,
VideoScriptRequest,
VideoTermsResponse,
VideoTermsRequest,
)
from app.services import llm
from app.utils import utils
@@ -9,23 +14,31 @@ from app.utils import utils
router = new_router()
@router.post("/scripts", response_model=VideoScriptResponse, summary="Create a script for the video")
@router.post(
"/scripts",
response_model=VideoScriptResponse,
summary="Create a script for the video",
)
def generate_video_script(request: Request, body: VideoScriptRequest):
video_script = llm.generate_script(video_subject=body.video_subject,
language=body.video_language,
paragraph_number=body.paragraph_number)
response = {
"video_script": video_script
}
video_script = llm.generate_script(
video_subject=body.video_subject,
language=body.video_language,
paragraph_number=body.paragraph_number,
)
response = {"video_script": video_script}
return utils.get_response(200, response)
@router.post("/terms", response_model=VideoTermsResponse, summary="Generate video terms based on the video script")
@router.post(
"/terms",
response_model=VideoTermsResponse,
summary="Generate video terms based on the video script",
)
def generate_video_terms(request: Request, body: VideoTermsRequest):
video_terms = llm.generate_terms(video_subject=body.video_subject,
video_script=body.video_script,
amount=body.amount)
response = {
"video_terms": video_terms
}
video_terms = llm.generate_terms(
video_subject=body.video_subject,
video_script=body.video_script,
amount=body.amount,
)
response = {"video_terms": video_terms}
return utils.get_response(200, response)