Merge branch 'harry0703:main' into main
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
import tomli
|
import toml
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
|
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
|
||||||
@@ -16,17 +16,17 @@ if not os.path.isfile(config_file):
|
|||||||
logger.info(f"load config from file: {config_file}")
|
logger.info(f"load config from file: {config_file}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(config_file, mode="rb") as fp:
|
_cfg = toml.load(config_file)
|
||||||
_cfg = tomli.load(fp)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"load config failed: {str(e)}, try to load as utf-8-sig")
|
logger.warning(f"load config failed: {str(e)}, try to load as utf-8-sig")
|
||||||
with open(config_file, mode="r", encoding='utf-8-sig') as fp:
|
with open(config_file, mode="r", encoding='utf-8-sig') as fp:
|
||||||
_cfg_content = fp.read()
|
_cfg_content = fp.read()
|
||||||
_cfg = tomli.loads(_cfg_content)
|
_cfg = toml.loads(_cfg_content)
|
||||||
|
|
||||||
app = _cfg.get("app", {})
|
app = _cfg.get("app", {})
|
||||||
whisper = _cfg.get("whisper", {})
|
whisper = _cfg.get("whisper", {})
|
||||||
pexels = _cfg.get("pexels", {})
|
pexels = _cfg.get("pexels", {})
|
||||||
|
ui = _cfg.get("ui", {})
|
||||||
|
|
||||||
hostname = socket.gethostname()
|
hostname = socket.gethostname()
|
||||||
|
|
||||||
@@ -47,9 +47,18 @@ ffmpeg_path = app.get("ffmpeg_path", "")
|
|||||||
if ffmpeg_path and os.path.isfile(ffmpeg_path):
|
if ffmpeg_path and os.path.isfile(ffmpeg_path):
|
||||||
os.environ["IMAGEIO_FFMPEG_EXE"] = ffmpeg_path
|
os.environ["IMAGEIO_FFMPEG_EXE"] = ffmpeg_path
|
||||||
|
|
||||||
|
|
||||||
# __cfg = {
|
# __cfg = {
|
||||||
# "hostname": hostname,
|
# "hostname": hostname,
|
||||||
# "listen_host": listen_host,
|
# "listen_host": listen_host,
|
||||||
# "listen_port": listen_port,
|
# "listen_port": listen_port,
|
||||||
# }
|
# }
|
||||||
# logger.info(__cfg)
|
# logger.info(__cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def save_config():
|
||||||
|
with open(config_file, "w", encoding="utf-8") as f:
|
||||||
|
_cfg["app"] = app
|
||||||
|
_cfg["whisper"] = whisper
|
||||||
|
_cfg["pexels"] = pexels
|
||||||
|
f.write(toml.dumps(_cfg))
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ def create_video(background_tasks: BackgroundTasks, request: Request, body: Task
|
|||||||
"request_id": request_id,
|
"request_id": request_id,
|
||||||
"params": body.dict(),
|
"params": body.dict(),
|
||||||
}
|
}
|
||||||
sm.update_task(task_id)
|
sm.state.update_task(task_id)
|
||||||
background_tasks.add_task(tm.start, task_id=task_id, params=body)
|
background_tasks.add_task(tm.start, task_id=task_id, params=body)
|
||||||
logger.success(f"video created: {utils.to_json(task)}")
|
logger.success(f"video created: {utils.to_json(task)}")
|
||||||
return utils.get_response(200, task)
|
return utils.get_response(200, task)
|
||||||
@@ -46,7 +46,7 @@ def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
|
|||||||
endpoint = endpoint.rstrip("/")
|
endpoint = endpoint.rstrip("/")
|
||||||
|
|
||||||
request_id = base.get_task_id(request)
|
request_id = base.get_task_id(request)
|
||||||
task = sm.get_task(task_id)
|
task = sm.state.get_task(task_id)
|
||||||
if task:
|
if task:
|
||||||
task_dir = utils.task_dir()
|
task_dir = utils.task_dir()
|
||||||
|
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ from typing import List
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from openai import AzureOpenAI
|
from openai import AzureOpenAI
|
||||||
import google.generativeai as genai
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
|
|
||||||
|
|
||||||
def _generate_response(prompt: str) -> str:
|
def _generate_response(prompt: str) -> str:
|
||||||
content = ""
|
content = ""
|
||||||
llm_provider = config.app.get("llm_provider", "openai")
|
llm_provider = config.app.get("llm_provider", "openai")
|
||||||
@@ -78,6 +78,7 @@ def _generate_response(prompt: str) -> str:
|
|||||||
return content.replace("\n", "")
|
return content.replace("\n", "")
|
||||||
|
|
||||||
if llm_provider == "gemini":
|
if llm_provider == "gemini":
|
||||||
|
import google.generativeai as genai
|
||||||
genai.configure(api_key=api_key)
|
genai.configure(api_key=api_key)
|
||||||
|
|
||||||
generation_config = {
|
generation_config = {
|
||||||
|
|||||||
@@ -11,13 +11,14 @@ from app.models.schema import VideoAspect, VideoConcatMode, MaterialInfo
|
|||||||
from app.utils import utils
|
from app.utils import utils
|
||||||
|
|
||||||
requested_count = 0
|
requested_count = 0
|
||||||
pexels_api_keys = config.app.get("pexels_api_keys")
|
|
||||||
if not pexels_api_keys:
|
|
||||||
raise ValueError(
|
|
||||||
f"\n\n##### pexels_api_keys is not set #####\n\nPlease set it in the config.toml file: {config.config_file}\n\n{utils.to_json(config.app)}")
|
|
||||||
|
|
||||||
|
|
||||||
def round_robin_api_key():
|
def round_robin_api_key():
|
||||||
|
pexels_api_keys = config.app.get("pexels_api_keys")
|
||||||
|
if not pexels_api_keys:
|
||||||
|
raise ValueError(
|
||||||
|
f"\n\n##### pexels_api_keys is not set #####\n\nPlease set it in the config.toml file: {config.config_file}\n\n{utils.to_json(config.app)}")
|
||||||
|
|
||||||
# if only one key is provided, return it
|
# if only one key is provided, return it
|
||||||
if isinstance(pexels_api_keys, str):
|
if isinstance(pexels_api_keys, str):
|
||||||
return pexels_api_keys
|
return pexels_api_keys
|
||||||
|
|||||||
@@ -1,35 +1,96 @@
|
|||||||
# State Management
|
import ast
|
||||||
# This module is responsible for managing the state of the application.
|
import json
|
||||||
import math
|
from abc import ABC, abstractmethod
|
||||||
|
import redis
|
||||||
# 如果你部署在分布式环境中,你可能需要一个中心化的状态管理服务,比如 Redis 或者数据库。
|
from app.config import config
|
||||||
# 如果你的应用程序是单机的,你可以使用内存来存储状态。
|
|
||||||
|
|
||||||
# If you are deploying in a distributed environment, you might need a centralized state management service like Redis or a database.
|
|
||||||
# If your application is single-node, you can use memory to store the state.
|
|
||||||
|
|
||||||
from app.models import const
|
from app.models import const
|
||||||
from app.utils import utils
|
|
||||||
|
|
||||||
_tasks = {}
|
|
||||||
|
|
||||||
|
|
||||||
def update_task(task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
|
# Base class for state management
|
||||||
"""
|
class BaseState(ABC):
|
||||||
Set the state of the task.
|
|
||||||
"""
|
@abstractmethod
|
||||||
|
def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_task(self, task_id: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Memory state management
|
||||||
|
class MemoryState(BaseState):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._tasks = {}
|
||||||
|
|
||||||
|
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
|
||||||
progress = int(progress)
|
progress = int(progress)
|
||||||
if progress > 100:
|
if progress > 100:
|
||||||
progress = 100
|
progress = 100
|
||||||
|
|
||||||
_tasks[task_id] = {
|
self._tasks[task_id] = {
|
||||||
"state": state,
|
"state": state,
|
||||||
"progress": progress,
|
"progress": progress,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_task(task_id: str):
|
def get_task(self, task_id: str):
|
||||||
|
return self._tasks.get(task_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
# Redis state management
|
||||||
|
class RedisState(BaseState):
|
||||||
|
|
||||||
|
def __init__(self, host='localhost', port=6379, db=0):
|
||||||
|
self._redis = redis.StrictRedis(host=host, port=port, db=db)
|
||||||
|
|
||||||
|
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
|
||||||
|
progress = int(progress)
|
||||||
|
if progress > 100:
|
||||||
|
progress = 100
|
||||||
|
|
||||||
|
fields = {
|
||||||
|
"state": state,
|
||||||
|
"progress": progress,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
for field, value in fields.items():
|
||||||
|
self._redis.hset(task_id, field, str(value))
|
||||||
|
|
||||||
|
def get_task(self, task_id: str):
|
||||||
|
task_data = self._redis.hgetall(task_id)
|
||||||
|
if not task_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()}
|
||||||
|
return task
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_to_original_type(value):
|
||||||
"""
|
"""
|
||||||
Get the state of the task.
|
Convert the value from byte string to its original data type.
|
||||||
|
You can extend this method to handle other data types as needed.
|
||||||
"""
|
"""
|
||||||
return _tasks.get(task_id, None)
|
value_str = value.decode('utf-8')
|
||||||
|
|
||||||
|
try:
|
||||||
|
# try to convert byte string array to list
|
||||||
|
return ast.literal_eval(value_str)
|
||||||
|
except (ValueError, SyntaxError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
if value_str.isdigit():
|
||||||
|
return int(value_str)
|
||||||
|
# Add more conversions here if needed
|
||||||
|
return value_str
|
||||||
|
|
||||||
|
|
||||||
|
# Global state
|
||||||
|
_enable_redis = config.app.get("enable_redis", False)
|
||||||
|
_redis_host = config.app.get("redis_host", "localhost")
|
||||||
|
_redis_port = config.app.get("redis_port", 6379)
|
||||||
|
_redis_db = config.app.get("redis_db", 0)
|
||||||
|
|
||||||
|
state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db) if _enable_redis else MemoryState()
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ def start(task_id, params: VideoParams):
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
logger.info(f"start task: {task_id}")
|
logger.info(f"start task: {task_id}")
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
|
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
|
||||||
|
|
||||||
video_subject = params.video_subject
|
video_subject = params.video_subject
|
||||||
voice_name = voice.parse_voice_name(params.voice_name)
|
voice_name = voice.parse_voice_name(params.voice_name)
|
||||||
@@ -44,7 +44,7 @@ def start(task_id, params: VideoParams):
|
|||||||
else:
|
else:
|
||||||
logger.debug(f"video script: \n{video_script}")
|
logger.debug(f"video script: \n{video_script}")
|
||||||
|
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
|
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
|
||||||
|
|
||||||
logger.info("\n\n## generating video terms")
|
logger.info("\n\n## generating video terms")
|
||||||
video_terms = params.video_terms
|
video_terms = params.video_terms
|
||||||
@@ -70,13 +70,13 @@ def start(task_id, params: VideoParams):
|
|||||||
with open(script_file, "w", encoding="utf-8") as f:
|
with open(script_file, "w", encoding="utf-8") as f:
|
||||||
f.write(utils.to_json(script_data))
|
f.write(utils.to_json(script_data))
|
||||||
|
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
|
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
|
||||||
|
|
||||||
logger.info("\n\n## generating audio")
|
logger.info("\n\n## generating audio")
|
||||||
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3")
|
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3")
|
||||||
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file)
|
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file)
|
||||||
if sub_maker is None:
|
if sub_maker is None:
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
|
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||||
logger.error(
|
logger.error(
|
||||||
"failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.")
|
"failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.")
|
||||||
return
|
return
|
||||||
@@ -84,7 +84,7 @@ def start(task_id, params: VideoParams):
|
|||||||
audio_duration = voice.get_audio_duration(sub_maker)
|
audio_duration = voice.get_audio_duration(sub_maker)
|
||||||
audio_duration = math.ceil(audio_duration)
|
audio_duration = math.ceil(audio_duration)
|
||||||
|
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
|
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
|
||||||
|
|
||||||
subtitle_path = ""
|
subtitle_path = ""
|
||||||
if params.subtitle_enabled:
|
if params.subtitle_enabled:
|
||||||
@@ -108,7 +108,7 @@ def start(task_id, params: VideoParams):
|
|||||||
logger.warning(f"subtitle file is invalid: {subtitle_path}")
|
logger.warning(f"subtitle file is invalid: {subtitle_path}")
|
||||||
subtitle_path = ""
|
subtitle_path = ""
|
||||||
|
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
|
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
|
||||||
|
|
||||||
logger.info("\n\n## downloading videos")
|
logger.info("\n\n## downloading videos")
|
||||||
downloaded_videos = material.download_videos(task_id=task_id,
|
downloaded_videos = material.download_videos(task_id=task_id,
|
||||||
@@ -119,12 +119,12 @@ def start(task_id, params: VideoParams):
|
|||||||
max_clip_duration=max_clip_duration,
|
max_clip_duration=max_clip_duration,
|
||||||
)
|
)
|
||||||
if not downloaded_videos:
|
if not downloaded_videos:
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
|
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||||
logger.error(
|
logger.error(
|
||||||
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.")
|
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.")
|
||||||
return
|
return
|
||||||
|
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
|
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
|
||||||
|
|
||||||
final_video_paths = []
|
final_video_paths = []
|
||||||
combined_video_paths = []
|
combined_video_paths = []
|
||||||
@@ -146,7 +146,7 @@ def start(task_id, params: VideoParams):
|
|||||||
threads=n_threads)
|
threads=n_threads)
|
||||||
|
|
||||||
_progress += 50 / params.video_count / 2
|
_progress += 50 / params.video_count / 2
|
||||||
sm.update_task(task_id, progress=_progress)
|
sm.state.update_task(task_id, progress=_progress)
|
||||||
|
|
||||||
final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")
|
final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")
|
||||||
|
|
||||||
@@ -160,7 +160,7 @@ def start(task_id, params: VideoParams):
|
|||||||
)
|
)
|
||||||
|
|
||||||
_progress += 50 / params.video_count / 2
|
_progress += 50 / params.video_count / 2
|
||||||
sm.update_task(task_id, progress=_progress)
|
sm.state.update_task(task_id, progress=_progress)
|
||||||
|
|
||||||
final_video_paths.append(final_video_path)
|
final_video_paths.append(final_video_path)
|
||||||
combined_video_paths.append(combined_video_path)
|
combined_video_paths.append(combined_video_path)
|
||||||
@@ -171,5 +171,5 @@ def start(task_id, params: VideoParams):
|
|||||||
"videos": final_video_paths,
|
"videos": final_video_paths,
|
||||||
"combined_videos": combined_video_paths
|
"combined_videos": combined_video_paths
|
||||||
}
|
}
|
||||||
sm.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
|
sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import locale
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import threading
|
import threading
|
||||||
@@ -174,3 +175,25 @@ def split_string_by_punctuations(s):
|
|||||||
def md5(text):
|
def md5(text):
|
||||||
import hashlib
|
import hashlib
|
||||||
return hashlib.md5(text.encode('utf-8')).hexdigest()
|
return hashlib.md5(text.encode('utf-8')).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_locale():
|
||||||
|
try:
|
||||||
|
loc = locale.getdefaultlocale()
|
||||||
|
# zh_CN, zh_TW return zh
|
||||||
|
# en_US, en_GB return en
|
||||||
|
language_code = loc[0].split("_")[0]
|
||||||
|
return language_code
|
||||||
|
except Exception as e:
|
||||||
|
return "en"
|
||||||
|
|
||||||
|
|
||||||
|
def load_locales(i18n_dir):
|
||||||
|
_locales = {}
|
||||||
|
for root, dirs, files in os.walk(i18n_dir):
|
||||||
|
for file in files:
|
||||||
|
if file.endswith(".json"):
|
||||||
|
lang = file.split(".")[0]
|
||||||
|
with open(os.path.join(root, file), "r", encoding="utf-8") as f:
|
||||||
|
_locales[lang] = json.loads(f.read())
|
||||||
|
return _locales
|
||||||
|
|||||||
@@ -129,6 +129,11 @@
|
|||||||
|
|
||||||
material_directory = ""
|
material_directory = ""
|
||||||
|
|
||||||
|
# Used for state management of the task
|
||||||
|
enable_redis = false
|
||||||
|
redis_host = "localhost"
|
||||||
|
redis_port = 6379
|
||||||
|
redis_db = 0
|
||||||
|
|
||||||
[whisper]
|
[whisper]
|
||||||
# Only effective when subtitle_provider is "whisper"
|
# Only effective when subtitle_provider is "whisper"
|
||||||
|
|||||||
@@ -16,3 +16,4 @@ g4f~=0.2.5.4
|
|||||||
dashscope~=1.15.0
|
dashscope~=1.15.0
|
||||||
google.generativeai~=0.4.1
|
google.generativeai~=0.4.1
|
||||||
python-multipart~=0.0.9
|
python-multipart~=0.0.9
|
||||||
|
redis==5.0.3
|
||||||
@@ -1,4 +1,2 @@
|
|||||||
set CURRENT_DIR=%CD%
|
|
||||||
set PYTHONPATH=%CURRENT_DIR%
|
|
||||||
rem set HF_ENDPOINT=https://hf-mirror.com
|
rem set HF_ENDPOINT=https://hf-mirror.com
|
||||||
streamlit run .\webui\Main.py
|
streamlit run .\webui\Main.py --browser.gatherUsageStats=False --server.enableCORS=True
|
||||||
4
webui.sh
4
webui.sh
@@ -1,7 +1,3 @@
|
|||||||
CURRENT_DIR=$(pwd)
|
|
||||||
echo "***** Current directory: $CURRENT_DIR *****"
|
|
||||||
export PYTHONPATH="${CURRENT_DIR}:$PYTHONPATH"
|
|
||||||
|
|
||||||
# If you could not download the model from the official site, you can use the mirror site.
|
# If you could not download the model from the official site, you can use the mirror site.
|
||||||
# Just remove the comment of the following line .
|
# Just remove the comment of the following line .
|
||||||
# 如果你无法从官方网站下载模型,你可以使用镜像网站。
|
# 如果你无法从官方网站下载模型,你可以使用镜像网站。
|
||||||
|
|||||||
152
webui/Main.py
152
webui/Main.py
@@ -9,15 +9,12 @@ if root_dir not in sys.path:
|
|||||||
print(sys.path)
|
print(sys.path)
|
||||||
print("")
|
print("")
|
||||||
|
|
||||||
import json
|
|
||||||
import locale
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
import platform
|
import platform
|
||||||
import streamlit.components.v1 as components
|
import streamlit.components.v1 as components
|
||||||
import toml
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
st.set_page_config(page_title="MoneyPrinterTurbo",
|
st.set_page_config(page_title="MoneyPrinterTurbo",
|
||||||
@@ -35,6 +32,7 @@ st.set_page_config(page_title="MoneyPrinterTurbo",
|
|||||||
from app.models.schema import VideoParams, VideoAspect, VideoConcatMode
|
from app.models.schema import VideoParams, VideoAspect, VideoConcatMode
|
||||||
from app.services import task as tm, llm, voice
|
from app.services import task as tm, llm, voice
|
||||||
from app.utils import utils
|
from app.utils import utils
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
hide_streamlit_style = """
|
hide_streamlit_style = """
|
||||||
<style>#root > div:nth-child(1) > div > div > div > div > section > div {padding-top: 0rem;}</style>
|
<style>#root > div:nth-child(1) > div > div > div > div > section > div {padding-top: 0rem;}</style>
|
||||||
@@ -46,33 +44,7 @@ font_dir = os.path.join(root_dir, "resource", "fonts")
|
|||||||
song_dir = os.path.join(root_dir, "resource", "songs")
|
song_dir = os.path.join(root_dir, "resource", "songs")
|
||||||
i18n_dir = os.path.join(root_dir, "webui", "i18n")
|
i18n_dir = os.path.join(root_dir, "webui", "i18n")
|
||||||
config_file = os.path.join(root_dir, "webui", ".streamlit", "webui.toml")
|
config_file = os.path.join(root_dir, "webui", ".streamlit", "webui.toml")
|
||||||
|
system_locale = utils.get_system_locale()
|
||||||
|
|
||||||
def load_config() -> dict:
|
|
||||||
try:
|
|
||||||
return toml.load(config_file)
|
|
||||||
except Exception as e:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
cfg = load_config()
|
|
||||||
|
|
||||||
|
|
||||||
def save_config():
|
|
||||||
with open(config_file, "w", encoding="utf-8") as f:
|
|
||||||
f.write(toml.dumps(cfg))
|
|
||||||
|
|
||||||
|
|
||||||
def get_system_locale():
|
|
||||||
try:
|
|
||||||
loc = locale.getdefaultlocale()
|
|
||||||
# zh_CN, zh_TW return zh
|
|
||||||
# en_US, en_GB return en
|
|
||||||
language_code = loc[0].split("_")[0]
|
|
||||||
return language_code
|
|
||||||
except Exception as e:
|
|
||||||
return "en"
|
|
||||||
|
|
||||||
|
|
||||||
if 'video_subject' not in st.session_state:
|
if 'video_subject' not in st.session_state:
|
||||||
st.session_state['video_subject'] = ''
|
st.session_state['video_subject'] = ''
|
||||||
@@ -81,7 +53,7 @@ if 'video_script' not in st.session_state:
|
|||||||
if 'video_terms' not in st.session_state:
|
if 'video_terms' not in st.session_state:
|
||||||
st.session_state['video_terms'] = ''
|
st.session_state['video_terms'] = ''
|
||||||
if 'ui_language' not in st.session_state:
|
if 'ui_language' not in st.session_state:
|
||||||
st.session_state['ui_language'] = cfg.get("ui_language", get_system_locale())
|
st.session_state['ui_language'] = config.ui.get("language", system_locale)
|
||||||
|
|
||||||
|
|
||||||
def get_all_fonts():
|
def get_all_fonts():
|
||||||
@@ -163,19 +135,7 @@ def init_log():
|
|||||||
|
|
||||||
init_log()
|
init_log()
|
||||||
|
|
||||||
|
locales = utils.load_locales(i18n_dir)
|
||||||
def load_locales():
|
|
||||||
locales = {}
|
|
||||||
for root, dirs, files in os.walk(i18n_dir):
|
|
||||||
for file in files:
|
|
||||||
if file.endswith(".json"):
|
|
||||||
lang = file.split(".")[0]
|
|
||||||
with open(os.path.join(root, file), "r", encoding="utf-8") as f:
|
|
||||||
locales[lang] = json.loads(f.read())
|
|
||||||
return locales
|
|
||||||
|
|
||||||
|
|
||||||
locales = load_locales()
|
|
||||||
|
|
||||||
|
|
||||||
def tr(key):
|
def tr(key):
|
||||||
@@ -183,20 +143,76 @@ def tr(key):
|
|||||||
return loc.get("Translation", {}).get(key, key)
|
return loc.get("Translation", {}).get(key, key)
|
||||||
|
|
||||||
|
|
||||||
display_languages = []
|
st.write(tr("Get Help"))
|
||||||
selected_index = 0
|
|
||||||
for i, code in enumerate(locales.keys()):
|
with st.expander(tr("Basic Settings"), expanded=False):
|
||||||
|
config_panels = st.columns(3)
|
||||||
|
left_config_panel = config_panels[0]
|
||||||
|
middle_config_panel = config_panels[1]
|
||||||
|
right_config_panel = config_panels[2]
|
||||||
|
with left_config_panel:
|
||||||
|
display_languages = []
|
||||||
|
selected_index = 0
|
||||||
|
for i, code in enumerate(locales.keys()):
|
||||||
display_languages.append(f"{code} - {locales[code].get('Language')}")
|
display_languages.append(f"{code} - {locales[code].get('Language')}")
|
||||||
if code == st.session_state['ui_language']:
|
if code == st.session_state['ui_language']:
|
||||||
selected_index = i
|
selected_index = i
|
||||||
|
|
||||||
selected_language = st.selectbox("Language", options=display_languages, label_visibility='collapsed',
|
selected_language = st.selectbox(tr("Language"), options=display_languages,
|
||||||
index=selected_index)
|
index=selected_index)
|
||||||
if selected_language:
|
if selected_language:
|
||||||
code = selected_language.split(" - ")[0].strip()
|
code = selected_language.split(" - ")[0].strip()
|
||||||
st.session_state['ui_language'] = code
|
st.session_state['ui_language'] = code
|
||||||
cfg['ui_language'] = code
|
config.ui['language'] = code
|
||||||
save_config()
|
config.save_config()
|
||||||
|
|
||||||
|
with middle_config_panel:
|
||||||
|
# openai
|
||||||
|
# moonshot (月之暗面)
|
||||||
|
# oneapi
|
||||||
|
# g4f
|
||||||
|
# azure
|
||||||
|
# qwen (通义千问)
|
||||||
|
# gemini
|
||||||
|
# ollama
|
||||||
|
llm_providers = ['OpenAI', 'Moonshot', 'Azure', 'Qwen', 'Gemini', 'Ollama', 'G4f', 'OneAPI']
|
||||||
|
saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
|
||||||
|
saved_llm_provider_index = 0
|
||||||
|
for i, provider in enumerate(llm_providers):
|
||||||
|
if provider.lower() == saved_llm_provider:
|
||||||
|
saved_llm_provider_index = i
|
||||||
|
break
|
||||||
|
|
||||||
|
llm_provider = st.selectbox(tr("LLM Provider"), options=llm_providers, index=saved_llm_provider_index)
|
||||||
|
llm_provider = llm_provider.lower()
|
||||||
|
config.app["llm_provider"] = llm_provider
|
||||||
|
|
||||||
|
llm_api_key = config.app.get(f"{llm_provider}_api_key", "")
|
||||||
|
llm_base_url = config.app.get(f"{llm_provider}_base_url", "")
|
||||||
|
llm_model_name = config.app.get(f"{llm_provider}_model_name", "")
|
||||||
|
st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password")
|
||||||
|
st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url)
|
||||||
|
st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name)
|
||||||
|
if st_llm_api_key:
|
||||||
|
config.app[f"{llm_provider}_api_key"] = st_llm_api_key
|
||||||
|
if st_llm_base_url:
|
||||||
|
config.app[f"{llm_provider}_base_url"] = st_llm_base_url
|
||||||
|
if st_llm_model_name:
|
||||||
|
config.app[f"{llm_provider}_model_name"] = st_llm_model_name
|
||||||
|
|
||||||
|
config.save_config()
|
||||||
|
|
||||||
|
with right_config_panel:
|
||||||
|
pexels_api_keys = config.app.get("pexels_api_keys", [])
|
||||||
|
if isinstance(pexels_api_keys, str):
|
||||||
|
pexels_api_keys = [pexels_api_keys]
|
||||||
|
pexels_api_key = ", ".join(pexels_api_keys)
|
||||||
|
|
||||||
|
pexels_api_key = st.text_input(tr("Pexels API Key"), value=pexels_api_key, type="password")
|
||||||
|
pexels_api_key = pexels_api_key.replace(" ", "")
|
||||||
|
if pexels_api_key:
|
||||||
|
config.app["pexels_api_keys"] = pexels_api_key.split(",")
|
||||||
|
config.save_config()
|
||||||
|
|
||||||
panel = st.columns(3)
|
panel = st.columns(3)
|
||||||
left_panel = panel[0]
|
left_panel = panel[0]
|
||||||
@@ -286,7 +302,7 @@ with middle_panel:
|
|||||||
replace("Male", tr("Male")).
|
replace("Male", tr("Male")).
|
||||||
replace("Neural", "") for
|
replace("Neural", "") for
|
||||||
voice in voices}
|
voice in voices}
|
||||||
saved_voice_name = cfg.get("voice_name", "")
|
saved_voice_name = config.ui.get("voice_name", "")
|
||||||
saved_voice_name_index = 0
|
saved_voice_name_index = 0
|
||||||
if saved_voice_name in friendly_names:
|
if saved_voice_name in friendly_names:
|
||||||
saved_voice_name_index = list(friendly_names.keys()).index(saved_voice_name)
|
saved_voice_name_index = list(friendly_names.keys()).index(saved_voice_name)
|
||||||
@@ -302,8 +318,8 @@ with middle_panel:
|
|||||||
|
|
||||||
voice_name = list(friendly_names.keys())[list(friendly_names.values()).index(selected_friendly_name)]
|
voice_name = list(friendly_names.keys())[list(friendly_names.values()).index(selected_friendly_name)]
|
||||||
params.voice_name = voice_name
|
params.voice_name = voice_name
|
||||||
cfg['voice_name'] = voice_name
|
config.ui['voice_name'] = voice_name
|
||||||
save_config()
|
config.save_config()
|
||||||
|
|
||||||
params.voice_volume = st.selectbox(tr("Speech Volume"),
|
params.voice_volume = st.selectbox(tr("Speech Volume"),
|
||||||
options=[0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0], index=2)
|
options=[0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0], index=2)
|
||||||
@@ -334,7 +350,13 @@ with right_panel:
|
|||||||
st.write(tr("Subtitle Settings"))
|
st.write(tr("Subtitle Settings"))
|
||||||
params.subtitle_enabled = st.checkbox(tr("Enable Subtitles"), value=True)
|
params.subtitle_enabled = st.checkbox(tr("Enable Subtitles"), value=True)
|
||||||
font_names = get_all_fonts()
|
font_names = get_all_fonts()
|
||||||
params.font_name = st.selectbox(tr("Font"), font_names)
|
saved_font_name = config.ui.get("font_name", "")
|
||||||
|
saved_font_name_index = 0
|
||||||
|
if saved_font_name in font_names:
|
||||||
|
saved_font_name_index = font_names.index(saved_font_name)
|
||||||
|
params.font_name = st.selectbox(tr("Font"), font_names, index=saved_font_name_index)
|
||||||
|
config.ui['font_name'] = params.font_name
|
||||||
|
config.save_config()
|
||||||
|
|
||||||
subtitle_positions = [
|
subtitle_positions = [
|
||||||
(tr("Top"), "top"),
|
(tr("Top"), "top"),
|
||||||
@@ -350,9 +372,14 @@ with right_panel:
|
|||||||
|
|
||||||
font_cols = st.columns([0.3, 0.7])
|
font_cols = st.columns([0.3, 0.7])
|
||||||
with font_cols[0]:
|
with font_cols[0]:
|
||||||
params.text_fore_color = st.color_picker(tr("Font Color"), "#FFFFFF")
|
saved_text_fore_color = config.ui.get("text_fore_color", "#FFFFFF")
|
||||||
|
params.text_fore_color = st.color_picker(tr("Font Color"), saved_text_fore_color)
|
||||||
|
config.ui['text_fore_color'] = params.text_fore_color
|
||||||
|
|
||||||
with font_cols[1]:
|
with font_cols[1]:
|
||||||
params.font_size = st.slider(tr("Font Size"), 30, 100, 60)
|
saved_font_size = config.ui.get("font_size", 60)
|
||||||
|
params.font_size = st.slider(tr("Font Size"), 30, 100, saved_font_size)
|
||||||
|
config.ui['font_size'] = params.font_size
|
||||||
|
|
||||||
stroke_cols = st.columns([0.3, 0.7])
|
stroke_cols = st.columns([0.3, 0.7])
|
||||||
with stroke_cols[0]:
|
with stroke_cols[0]:
|
||||||
@@ -362,12 +389,23 @@ with right_panel:
|
|||||||
|
|
||||||
start_button = st.button(tr("Generate Video"), use_container_width=True, type="primary")
|
start_button = st.button(tr("Generate Video"), use_container_width=True, type="primary")
|
||||||
if start_button:
|
if start_button:
|
||||||
|
config.save_config()
|
||||||
task_id = str(uuid4())
|
task_id = str(uuid4())
|
||||||
if not params.video_subject and not params.video_script:
|
if not params.video_subject and not params.video_script:
|
||||||
st.error(tr("Video Script and Subject Cannot Both Be Empty"))
|
st.error(tr("Video Script and Subject Cannot Both Be Empty"))
|
||||||
scroll_to_bottom()
|
scroll_to_bottom()
|
||||||
st.stop()
|
st.stop()
|
||||||
|
|
||||||
|
if not config.app.get(f"{llm_provider}_api_key", ""):
|
||||||
|
st.error(tr("Please Enter the LLM API Key"))
|
||||||
|
scroll_to_bottom()
|
||||||
|
st.stop()
|
||||||
|
|
||||||
|
if not config.app.get("pexels_api_keys", ""):
|
||||||
|
st.error(tr("Please Enter the Pexels API Key"))
|
||||||
|
scroll_to_bottom()
|
||||||
|
st.stop()
|
||||||
|
|
||||||
log_container = st.empty()
|
log_container = st.empty()
|
||||||
log_records = []
|
log_records = []
|
||||||
|
|
||||||
|
|||||||
@@ -48,6 +48,16 @@
|
|||||||
"Generating Video": "Video wird erstellt, bitte warten...",
|
"Generating Video": "Video wird erstellt, bitte warten...",
|
||||||
"Start Generating Video": "Beginne mit der Generierung",
|
"Start Generating Video": "Beginne mit der Generierung",
|
||||||
"Video Generation Completed": "Video erfolgreich generiert",
|
"Video Generation Completed": "Video erfolgreich generiert",
|
||||||
"You can download the generated video from the following links": "Sie können das generierte Video über die folgenden Links herunterladen"
|
"You can download the generated video from the following links": "Sie können das generierte Video über die folgenden Links herunterladen",
|
||||||
|
"Basic Settings": "**Grunde Instellungen**",
|
||||||
|
"Pexels API Key": "Pexels API Key (:red[Required] [Get API Key](https://www.pexels.com/api/))",
|
||||||
|
"Language": "Language",
|
||||||
|
"LLM Provider": "LLM Provider",
|
||||||
|
"API Key": "API Key (:red[Required])",
|
||||||
|
"Base Url": "Base Url",
|
||||||
|
"Model Name": "Model Name",
|
||||||
|
"Please Enter the LLM API Key": "Please Enter the **LLM API Key**",
|
||||||
|
"Please Enter the Pexels API Key": "Please Enter the **Pexels API Key**",
|
||||||
|
"Get Help": "If you need help, or have any questions, you can join discord for help: https://harryai.cc/moneyprinterturbo"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -48,6 +48,16 @@
|
|||||||
"Generating Video": "Generating video, please wait...",
|
"Generating Video": "Generating video, please wait...",
|
||||||
"Start Generating Video": "Start Generating Video",
|
"Start Generating Video": "Start Generating Video",
|
||||||
"Video Generation Completed": "Video Generation Completed",
|
"Video Generation Completed": "Video Generation Completed",
|
||||||
"You can download the generated video from the following links": "You can download the generated video from the following links"
|
"You can download the generated video from the following links": "You can download the generated video from the following links",
|
||||||
|
"Pexels API Key": "Pexels API Key (:red[Required] [Get API Key](https://www.pexels.com/api/))",
|
||||||
|
"Basic Settings": "**Basic Settings** (:blue[Click to expand])",
|
||||||
|
"Language": "Language",
|
||||||
|
"LLM Provider": "LLM Provider",
|
||||||
|
"API Key": "API Key (:red[Required])",
|
||||||
|
"Base Url": "Base Url",
|
||||||
|
"Model Name": "Model Name",
|
||||||
|
"Please Enter the LLM API Key": "Please Enter the **LLM API Key**",
|
||||||
|
"Please Enter the Pexels API Key": "Please Enter the **Pexels API Key**",
|
||||||
|
"Get Help": "If you need help, or have any questions, you can join discord for help: https://harryai.cc/moneyprinterturbo"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -48,6 +48,16 @@
|
|||||||
"Generating Video": "正在生成视频,请稍候...",
|
"Generating Video": "正在生成视频,请稍候...",
|
||||||
"Start Generating Video": "开始生成视频",
|
"Start Generating Video": "开始生成视频",
|
||||||
"Video Generation Completed": "视频生成完成",
|
"Video Generation Completed": "视频生成完成",
|
||||||
"You can download the generated video from the following links": "你可以从以下链接下载生成的视频"
|
"You can download the generated video from the following links": "你可以从以下链接下载生成的视频",
|
||||||
|
"Basic Settings": "**基础设置** (:blue[点击展开])",
|
||||||
|
"Language": "界面语言",
|
||||||
|
"Pexels API Key": "Pexels API Key (:red[必填] [点击获取](https://www.pexels.com/api/))",
|
||||||
|
"LLM Provider": "大模型提供商",
|
||||||
|
"API Key": "API Key (:red[必填,需要到大模型提供商的后台申请])",
|
||||||
|
"Base Url": "Base Url (可选)",
|
||||||
|
"Model Name": "模型名称 (:blue[需要到大模型提供商的后台确认被授权的模型名称])",
|
||||||
|
"Please Enter the LLM API Key": "请先填写大模型 **API Key**",
|
||||||
|
"Please Enter the Pexels API Key": "请先填写 **Pexels API Key**",
|
||||||
|
"Get Help": "有任何问题或建议,可以加入 **微信群** 求助或讨论:https://harryai.cc/moneyprinterturbo"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user