chore: add stream support for video

This commit is contained in:
kevin.zhang
2024-04-12 17:43:21 +08:00
parent 740960b0ff
commit 1fb3399b02
3 changed files with 17 additions and 4 deletions

View File

@@ -21,3 +21,4 @@ __pycache__/
.svn/ .svn/
storage/ storage/
config.toml

View File

@@ -3,6 +3,7 @@ import glob
import shutil import shutil
from fastapi import Request, Depends, Path, BackgroundTasks, UploadFile from fastapi import Request, Depends, Path, BackgroundTasks, UploadFile
from fastapi.responses import FileResponse
from fastapi.params import File from fastapi.params import File
from loguru import logger from loguru import logger
@@ -78,7 +79,7 @@ def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
@router.delete("/tasks/{task_id}", response_model=TaskDeletionResponse, summary="Delete a generated short video task") @router.delete("/tasks/{task_id}", response_model=TaskDeletionResponse, summary="Delete a generated short video task")
def create_video(request: Request, task_id: str = Path(..., description="Task ID")): def delete_video(request: Request, task_id: str = Path(..., description="Task ID")):
request_id = base.get_task_id(request) request_id = base.get_task_id(request)
task = sm.state.get_task(task_id) task = sm.state.get_task(task_id)
if task: if task:
@@ -130,3 +131,13 @@ def upload_bgm_file(request: Request, file: UploadFile = File(...)):
return utils.get_response(200, response) return utils.get_response(200, response)
raise HttpException('', status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded") raise HttpException('', status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded")
@router.get("/stream/{file_path:path}")
async def stream_video(request: Request, file_path: str):
tasks_dir = utils.task_dir()
video_path = os.path.join(tasks_dir, file_path)
if os.path.isfile(video_path):
return FileResponse(video_path, media_type="video/mp4", filename=file_path)
else:
return {"message": "File not found."}

View File

@@ -44,9 +44,9 @@ class MemoryState(BaseState):
# Redis state management # Redis state management
class RedisState(BaseState): class RedisState(BaseState):
def __init__(self, host='localhost', port=6379, db=0): def __init__(self, host='localhost', port=6379, db=0, password=None):
import redis import redis
self._redis = redis.StrictRedis(host=host, port=port, db=db) self._redis = redis.StrictRedis(host=host, port=port, db=db, password=password)
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs): def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
progress = int(progress) progress = int(progress)
@@ -98,5 +98,6 @@ _enable_redis = config.app.get("enable_redis", False)
_redis_host = config.app.get("redis_host", "localhost") _redis_host = config.app.get("redis_host", "localhost")
_redis_port = config.app.get("redis_port", 6379) _redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0) _redis_db = config.app.get("redis_db", 0)
_redis_password = config.app.get("redis_password", None)
state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db) if _enable_redis else MemoryState() state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db, password=_redis_password) if _enable_redis else MemoryState()