Format project code
This commit is contained in:
@@ -21,6 +21,7 @@ def _generate_response(prompt: str) -> str:
|
||||
if not model_name:
|
||||
model_name = "gpt-3.5-turbo-16k-0613"
|
||||
import g4f
|
||||
|
||||
content = g4f.ChatCompletion.create(
|
||||
model=model_name,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
@@ -78,44 +79,56 @@ def _generate_response(prompt: str) -> str:
|
||||
base_url = config.app.get("ernie_base_url")
|
||||
model_name = "***"
|
||||
if not secret_key:
|
||||
raise ValueError(f"{llm_provider}: secret_key is not set, please set it in the config.toml file.")
|
||||
raise ValueError(
|
||||
f"{llm_provider}: secret_key is not set, please set it in the config.toml file."
|
||||
)
|
||||
else:
|
||||
raise ValueError("llm_provider is not set, please set it in the config.toml file.")
|
||||
raise ValueError(
|
||||
"llm_provider is not set, please set it in the config.toml file."
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
raise ValueError(f"{llm_provider}: api_key is not set, please set it in the config.toml file.")
|
||||
raise ValueError(
|
||||
f"{llm_provider}: api_key is not set, please set it in the config.toml file."
|
||||
)
|
||||
if not model_name:
|
||||
raise ValueError(f"{llm_provider}: model_name is not set, please set it in the config.toml file.")
|
||||
raise ValueError(
|
||||
f"{llm_provider}: model_name is not set, please set it in the config.toml file."
|
||||
)
|
||||
if not base_url:
|
||||
raise ValueError(f"{llm_provider}: base_url is not set, please set it in the config.toml file.")
|
||||
raise ValueError(
|
||||
f"{llm_provider}: base_url is not set, please set it in the config.toml file."
|
||||
)
|
||||
|
||||
if llm_provider == "qwen":
|
||||
import dashscope
|
||||
from dashscope.api_entities.dashscope_response import GenerationResponse
|
||||
|
||||
dashscope.api_key = api_key
|
||||
response = dashscope.Generation.call(
|
||||
model=model_name,
|
||||
messages=[{"role": "user", "content": prompt}]
|
||||
model=model_name, messages=[{"role": "user", "content": prompt}]
|
||||
)
|
||||
if response:
|
||||
if isinstance(response, GenerationResponse):
|
||||
status_code = response.status_code
|
||||
if status_code != 200:
|
||||
raise Exception(
|
||||
f"[{llm_provider}] returned an error response: \"{response}\"")
|
||||
f'[{llm_provider}] returned an error response: "{response}"'
|
||||
)
|
||||
|
||||
content = response["output"]["text"]
|
||||
return content.replace("\n", "")
|
||||
else:
|
||||
raise Exception(
|
||||
f"[{llm_provider}] returned an invalid response: \"{response}\"")
|
||||
f'[{llm_provider}] returned an invalid response: "{response}"'
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"[{llm_provider}] returned an empty response")
|
||||
raise Exception(f"[{llm_provider}] returned an empty response")
|
||||
|
||||
if llm_provider == "gemini":
|
||||
import google.generativeai as genai
|
||||
genai.configure(api_key=api_key, transport='rest')
|
||||
|
||||
genai.configure(api_key=api_key, transport="rest")
|
||||
|
||||
generation_config = {
|
||||
"temperature": 0.5,
|
||||
@@ -127,25 +140,27 @@ def _generate_response(prompt: str) -> str:
|
||||
safety_settings = [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"threshold": "BLOCK_ONLY_HIGH"
|
||||
"threshold": "BLOCK_ONLY_HIGH",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"threshold": "BLOCK_ONLY_HIGH"
|
||||
"threshold": "BLOCK_ONLY_HIGH",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"threshold": "BLOCK_ONLY_HIGH"
|
||||
"threshold": "BLOCK_ONLY_HIGH",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"threshold": "BLOCK_ONLY_HIGH"
|
||||
"threshold": "BLOCK_ONLY_HIGH",
|
||||
},
|
||||
]
|
||||
|
||||
model = genai.GenerativeModel(model_name=model_name,
|
||||
generation_config=generation_config,
|
||||
safety_settings=safety_settings)
|
||||
model = genai.GenerativeModel(
|
||||
model_name=model_name,
|
||||
generation_config=generation_config,
|
||||
safety_settings=safety_settings,
|
||||
)
|
||||
|
||||
try:
|
||||
response = model.generate_content(prompt)
|
||||
@@ -158,15 +173,16 @@ def _generate_response(prompt: str) -> str:
|
||||
|
||||
if llm_provider == "cloudflare":
|
||||
import requests
|
||||
|
||||
response = requests.post(
|
||||
f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
json={
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a friendly assistant"},
|
||||
{"role": "user", "content": prompt}
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
}
|
||||
},
|
||||
)
|
||||
result = response.json()
|
||||
logger.info(result)
|
||||
@@ -174,30 +190,35 @@ def _generate_response(prompt: str) -> str:
|
||||
|
||||
if llm_provider == "ernie":
|
||||
import requests
|
||||
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
|
||||
access_token = requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params).json().get(
|
||||
"access_token")
|
||||
|
||||
params = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": api_key,
|
||||
"client_secret": secret_key,
|
||||
}
|
||||
access_token = (
|
||||
requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params)
|
||||
.json()
|
||||
.get("access_token")
|
||||
)
|
||||
url = f"{base_url}?access_token={access_token}"
|
||||
|
||||
payload = json.dumps({
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
],
|
||||
"temperature": 0.5,
|
||||
"top_p": 0.8,
|
||||
"penalty_score": 1,
|
||||
"disable_search": False,
|
||||
"enable_citation": False,
|
||||
"response_format": "text"
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
payload = json.dumps(
|
||||
{
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.5,
|
||||
"top_p": 0.8,
|
||||
"penalty_score": 1,
|
||||
"disable_search": False,
|
||||
"enable_citation": False,
|
||||
"response_format": "text",
|
||||
}
|
||||
)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
response = requests.request("POST", url, headers=headers, data=payload).json()
|
||||
response = requests.request(
|
||||
"POST", url, headers=headers, data=payload
|
||||
).json()
|
||||
return response.get("result")
|
||||
|
||||
if llm_provider == "azure":
|
||||
@@ -213,24 +234,27 @@ def _generate_response(prompt: str) -> str:
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=[{"role": "user", "content": prompt}]
|
||||
model=model_name, messages=[{"role": "user", "content": prompt}]
|
||||
)
|
||||
if response:
|
||||
if isinstance(response, ChatCompletion):
|
||||
content = response.choices[0].message.content
|
||||
else:
|
||||
raise Exception(
|
||||
f"[{llm_provider}] returned an invalid response: \"{response}\", please check your network "
|
||||
f"connection and try again.")
|
||||
f'[{llm_provider}] returned an invalid response: "{response}", please check your network '
|
||||
f"connection and try again."
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"[{llm_provider}] returned an empty response, please check your network connection and try again.")
|
||||
f"[{llm_provider}] returned an empty response, please check your network connection and try again."
|
||||
)
|
||||
|
||||
return content.replace("\n", "")
|
||||
|
||||
|
||||
def generate_script(video_subject: str, language: str = "", paragraph_number: int = 1) -> str:
|
||||
def generate_script(
|
||||
video_subject: str, language: str = "", paragraph_number: int = 1
|
||||
) -> str:
|
||||
prompt = f"""
|
||||
# Role: Video Script Generator
|
||||
|
||||
@@ -335,14 +359,16 @@ Please note that you must use English for generating video search terms; Chinese
|
||||
try:
|
||||
response = _generate_response(prompt)
|
||||
search_terms = json.loads(response)
|
||||
if not isinstance(search_terms, list) or not all(isinstance(term, str) for term in search_terms):
|
||||
if not isinstance(search_terms, list) or not all(
|
||||
isinstance(term, str) for term in search_terms
|
||||
):
|
||||
logger.error("response is not a list of strings.")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"failed to generate video terms: {str(e)}")
|
||||
if response:
|
||||
match = re.search(r'\[.*]', response)
|
||||
match = re.search(r"\[.*]", response)
|
||||
if match:
|
||||
try:
|
||||
search_terms = json.loads(match.group())
|
||||
@@ -361,9 +387,13 @@ Please note that you must use English for generating video search terms; Chinese
|
||||
|
||||
if __name__ == "__main__":
|
||||
video_subject = "生命的意义是什么"
|
||||
script = generate_script(video_subject=video_subject, language="zh-CN", paragraph_number=1)
|
||||
script = generate_script(
|
||||
video_subject=video_subject, language="zh-CN", paragraph_number=1
|
||||
)
|
||||
print("######################")
|
||||
print(script)
|
||||
search_terms = generate_terms(video_subject=video_subject, video_script=script, amount=5)
|
||||
search_terms = generate_terms(
|
||||
video_subject=video_subject, video_script=script, amount=5
|
||||
)
|
||||
print("######################")
|
||||
print(search_terms)
|
||||
|
||||
@@ -19,7 +19,8 @@ def get_api_key(cfg_key: str):
|
||||
if not api_keys:
|
||||
raise ValueError(
|
||||
f"\n\n##### {cfg_key} is not set #####\n\nPlease set it in the config.toml file: {config.config_file}\n\n"
|
||||
f"{utils.to_json(config.app)}")
|
||||
f"{utils.to_json(config.app)}"
|
||||
)
|
||||
|
||||
# if only one key is provided, return it
|
||||
if isinstance(api_keys, str):
|
||||
@@ -30,28 +31,29 @@ def get_api_key(cfg_key: str):
|
||||
return api_keys[requested_count % len(api_keys)]
|
||||
|
||||
|
||||
def search_videos_pexels(search_term: str,
|
||||
minimum_duration: int,
|
||||
video_aspect: VideoAspect = VideoAspect.portrait,
|
||||
) -> List[MaterialInfo]:
|
||||
def search_videos_pexels(
|
||||
search_term: str,
|
||||
minimum_duration: int,
|
||||
video_aspect: VideoAspect = VideoAspect.portrait,
|
||||
) -> List[MaterialInfo]:
|
||||
aspect = VideoAspect(video_aspect)
|
||||
video_orientation = aspect.name
|
||||
video_width, video_height = aspect.to_resolution()
|
||||
api_key = get_api_key("pexels_api_keys")
|
||||
headers = {
|
||||
"Authorization": api_key
|
||||
}
|
||||
headers = {"Authorization": api_key}
|
||||
# Build URL
|
||||
params = {
|
||||
"query": search_term,
|
||||
"per_page": 20,
|
||||
"orientation": video_orientation
|
||||
}
|
||||
params = {"query": search_term, "per_page": 20, "orientation": video_orientation}
|
||||
query_url = f"https://api.pexels.com/videos/search?{urlencode(params)}"
|
||||
logger.info(f"searching videos: {query_url}, with proxies: {config.proxy}")
|
||||
|
||||
try:
|
||||
r = requests.get(query_url, headers=headers, proxies=config.proxy, verify=False, timeout=(30, 60))
|
||||
r = requests.get(
|
||||
query_url,
|
||||
headers=headers,
|
||||
proxies=config.proxy,
|
||||
verify=False,
|
||||
timeout=(30, 60),
|
||||
)
|
||||
response = r.json()
|
||||
video_items = []
|
||||
if "videos" not in response:
|
||||
@@ -83,10 +85,11 @@ def search_videos_pexels(search_term: str,
|
||||
return []
|
||||
|
||||
|
||||
def search_videos_pixabay(search_term: str,
|
||||
minimum_duration: int,
|
||||
video_aspect: VideoAspect = VideoAspect.portrait,
|
||||
) -> List[MaterialInfo]:
|
||||
def search_videos_pixabay(
|
||||
search_term: str,
|
||||
minimum_duration: int,
|
||||
video_aspect: VideoAspect = VideoAspect.portrait,
|
||||
) -> List[MaterialInfo]:
|
||||
aspect = VideoAspect(video_aspect)
|
||||
|
||||
video_width, video_height = aspect.to_resolution()
|
||||
@@ -97,13 +100,15 @@ def search_videos_pixabay(search_term: str,
|
||||
"q": search_term,
|
||||
"video_type": "all", # Accepted values: "all", "film", "animation"
|
||||
"per_page": 50,
|
||||
"key": api_key
|
||||
"key": api_key,
|
||||
}
|
||||
query_url = f"https://pixabay.com/api/videos/?{urlencode(params)}"
|
||||
logger.info(f"searching videos: {query_url}, with proxies: {config.proxy}")
|
||||
|
||||
try:
|
||||
r = requests.get(query_url, proxies=config.proxy, verify=False, timeout=(30, 60))
|
||||
r = requests.get(
|
||||
query_url, proxies=config.proxy, verify=False, timeout=(30, 60)
|
||||
)
|
||||
response = r.json()
|
||||
video_items = []
|
||||
if "hits" not in response:
|
||||
@@ -155,7 +160,11 @@ def save_video(video_url: str, save_dir: str = "") -> str:
|
||||
|
||||
# if video does not exist, download it
|
||||
with open(video_path, "wb") as f:
|
||||
f.write(requests.get(video_url, proxies=config.proxy, verify=False, timeout=(60, 240)).content)
|
||||
f.write(
|
||||
requests.get(
|
||||
video_url, proxies=config.proxy, verify=False, timeout=(60, 240)
|
||||
).content
|
||||
)
|
||||
|
||||
if os.path.exists(video_path) and os.path.getsize(video_path) > 0:
|
||||
try:
|
||||
@@ -174,14 +183,15 @@ def save_video(video_url: str, save_dir: str = "") -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def download_videos(task_id: str,
|
||||
search_terms: List[str],
|
||||
source: str = "pexels",
|
||||
video_aspect: VideoAspect = VideoAspect.portrait,
|
||||
video_contact_mode: VideoConcatMode = VideoConcatMode.random,
|
||||
audio_duration: float = 0.0,
|
||||
max_clip_duration: int = 5,
|
||||
) -> List[str]:
|
||||
def download_videos(
|
||||
task_id: str,
|
||||
search_terms: List[str],
|
||||
source: str = "pexels",
|
||||
video_aspect: VideoAspect = VideoAspect.portrait,
|
||||
video_contact_mode: VideoConcatMode = VideoConcatMode.random,
|
||||
audio_duration: float = 0.0,
|
||||
max_clip_duration: int = 5,
|
||||
) -> List[str]:
|
||||
valid_video_items = []
|
||||
valid_video_urls = []
|
||||
found_duration = 0.0
|
||||
@@ -190,9 +200,11 @@ def download_videos(task_id: str,
|
||||
search_videos = search_videos_pixabay
|
||||
|
||||
for search_term in search_terms:
|
||||
video_items = search_videos(search_term=search_term,
|
||||
minimum_duration=max_clip_duration,
|
||||
video_aspect=video_aspect)
|
||||
video_items = search_videos(
|
||||
search_term=search_term,
|
||||
minimum_duration=max_clip_duration,
|
||||
video_aspect=video_aspect,
|
||||
)
|
||||
logger.info(f"found {len(video_items)} videos for '{search_term}'")
|
||||
|
||||
for item in video_items:
|
||||
@@ -202,7 +214,8 @@ def download_videos(task_id: str,
|
||||
found_duration += item.duration
|
||||
|
||||
logger.info(
|
||||
f"found total videos: {len(valid_video_items)}, required duration: {audio_duration} seconds, found duration: {found_duration} seconds")
|
||||
f"found total videos: {len(valid_video_items)}, required duration: {audio_duration} seconds, found duration: {found_duration} seconds"
|
||||
)
|
||||
video_paths = []
|
||||
|
||||
material_directory = config.app.get("material_directory", "").strip()
|
||||
@@ -218,14 +231,18 @@ def download_videos(task_id: str,
|
||||
for item in valid_video_items:
|
||||
try:
|
||||
logger.info(f"downloading video: {item.url}")
|
||||
saved_video_path = save_video(video_url=item.url, save_dir=material_directory)
|
||||
saved_video_path = save_video(
|
||||
video_url=item.url, save_dir=material_directory
|
||||
)
|
||||
if saved_video_path:
|
||||
logger.info(f"video saved: {saved_video_path}")
|
||||
video_paths.append(saved_video_path)
|
||||
seconds = min(max_clip_duration, item.duration)
|
||||
total_duration += seconds
|
||||
if total_duration > audio_duration:
|
||||
logger.info(f"total duration of downloaded videos: {total_duration} seconds, skip downloading more")
|
||||
logger.info(
|
||||
f"total duration of downloaded videos: {total_duration} seconds, skip downloading more"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"failed to download video: {utils.to_json(item)} => {str(e)}")
|
||||
@@ -234,4 +251,6 @@ def download_videos(task_id: str,
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
download_videos("test123", ["Money Exchange Medium"], audio_duration=100, source="pixabay")
|
||||
download_videos(
|
||||
"test123", ["Money Exchange Medium"], audio_duration=100, source="pixabay"
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ from app.models import const
|
||||
|
||||
# Base class for state management
|
||||
class BaseState(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
|
||||
pass
|
||||
@@ -18,11 +17,16 @@ class BaseState(ABC):
|
||||
|
||||
# Memory state management
|
||||
class MemoryState(BaseState):
|
||||
|
||||
def __init__(self):
|
||||
self._tasks = {}
|
||||
|
||||
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
|
||||
def update_task(
|
||||
self,
|
||||
task_id: str,
|
||||
state: int = const.TASK_STATE_PROCESSING,
|
||||
progress: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
progress = int(progress)
|
||||
if progress > 100:
|
||||
progress = 100
|
||||
@@ -43,12 +47,18 @@ class MemoryState(BaseState):
|
||||
|
||||
# Redis state management
|
||||
class RedisState(BaseState):
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None):
|
||||
def __init__(self, host="localhost", port=6379, db=0, password=None):
|
||||
import redis
|
||||
|
||||
self._redis = redis.StrictRedis(host=host, port=port, db=db, password=password)
|
||||
|
||||
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
|
||||
def update_task(
|
||||
self,
|
||||
task_id: str,
|
||||
state: int = const.TASK_STATE_PROCESSING,
|
||||
progress: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
progress = int(progress)
|
||||
if progress > 100:
|
||||
progress = 100
|
||||
@@ -67,7 +77,10 @@ class RedisState(BaseState):
|
||||
if not task_data:
|
||||
return None
|
||||
|
||||
task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()}
|
||||
task = {
|
||||
key.decode("utf-8"): self._convert_to_original_type(value)
|
||||
for key, value in task_data.items()
|
||||
}
|
||||
return task
|
||||
|
||||
def delete_task(self, task_id: str):
|
||||
@@ -79,7 +92,7 @@ class RedisState(BaseState):
|
||||
Convert the value from byte string to its original data type.
|
||||
You can extend this method to handle other data types as needed.
|
||||
"""
|
||||
value_str = value.decode('utf-8')
|
||||
value_str = value.decode("utf-8")
|
||||
|
||||
try:
|
||||
# try to convert byte string array to list
|
||||
@@ -100,4 +113,10 @@ _redis_port = config.app.get("redis_port", 6379)
|
||||
_redis_db = config.app.get("redis_db", 0)
|
||||
_redis_password = config.app.get("redis_password", None)
|
||||
|
||||
state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db, password=_redis_password) if _enable_redis else MemoryState()
|
||||
state = (
|
||||
RedisState(
|
||||
host=_redis_host, port=_redis_port, db=_redis_db, password=_redis_password
|
||||
)
|
||||
if _enable_redis
|
||||
else MemoryState()
|
||||
)
|
||||
|
||||
@@ -23,18 +23,22 @@ def create(audio_file, subtitle_file: str = ""):
|
||||
if not os.path.isdir(model_path) or not os.path.isfile(model_bin_file):
|
||||
model_path = model_size
|
||||
|
||||
logger.info(f"loading model: {model_path}, device: {device}, compute_type: {compute_type}")
|
||||
logger.info(
|
||||
f"loading model: {model_path}, device: {device}, compute_type: {compute_type}"
|
||||
)
|
||||
try:
|
||||
model = WhisperModel(model_size_or_path=model_path,
|
||||
device=device,
|
||||
compute_type=compute_type)
|
||||
model = WhisperModel(
|
||||
model_size_or_path=model_path, device=device, compute_type=compute_type
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"failed to load model: {e} \n\n"
|
||||
f"********************************************\n"
|
||||
f"this may be caused by network issue. \n"
|
||||
f"please download the model manually and put it in the 'models' folder. \n"
|
||||
f"see [README.md FAQ](https://github.com/harry0703/MoneyPrinterTurbo) for more details.\n"
|
||||
f"********************************************\n\n")
|
||||
logger.error(
|
||||
f"failed to load model: {e} \n\n"
|
||||
f"********************************************\n"
|
||||
f"this may be caused by network issue. \n"
|
||||
f"please download the model manually and put it in the 'models' folder. \n"
|
||||
f"see [README.md FAQ](https://github.com/harry0703/MoneyPrinterTurbo) for more details.\n"
|
||||
f"********************************************\n\n"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(f"start, output file: {subtitle_file}")
|
||||
@@ -49,7 +53,9 @@ def create(audio_file, subtitle_file: str = ""):
|
||||
vad_parameters=dict(min_silence_duration_ms=500),
|
||||
)
|
||||
|
||||
logger.info(f"detected language: '{info.language}', probability: {info.language_probability:.2f}")
|
||||
logger.info(
|
||||
f"detected language: '{info.language}', probability: {info.language_probability:.2f}"
|
||||
)
|
||||
|
||||
start = timer()
|
||||
subtitles = []
|
||||
@@ -62,11 +68,9 @@ def create(audio_file, subtitle_file: str = ""):
|
||||
msg = "[%.2fs -> %.2fs] %s" % (seg_start, seg_end, seg_text)
|
||||
logger.debug(msg)
|
||||
|
||||
subtitles.append({
|
||||
"msg": seg_text,
|
||||
"start_time": seg_start,
|
||||
"end_time": seg_end
|
||||
})
|
||||
subtitles.append(
|
||||
{"msg": seg_text, "start_time": seg_start, "end_time": seg_end}
|
||||
)
|
||||
|
||||
for segment in segments:
|
||||
words_idx = 0
|
||||
@@ -119,7 +123,11 @@ def create(audio_file, subtitle_file: str = ""):
|
||||
for subtitle in subtitles:
|
||||
text = subtitle.get("msg")
|
||||
if text:
|
||||
lines.append(utils.text_to_srt(idx, text, subtitle.get("start_time"), subtitle.get("end_time")))
|
||||
lines.append(
|
||||
utils.text_to_srt(
|
||||
idx, text, subtitle.get("start_time"), subtitle.get("end_time")
|
||||
)
|
||||
)
|
||||
idx += 1
|
||||
|
||||
sub = "\n".join(lines) + "\n"
|
||||
@@ -136,12 +144,12 @@ def file_to_subtitles(filename):
|
||||
current_times = None
|
||||
current_text = ""
|
||||
index = 0
|
||||
with open(filename, 'r', encoding="utf-8") as f:
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
times = re.findall("([0-9]*:[0-9]*:[0-9]*,[0-9]*)", line)
|
||||
if times:
|
||||
current_times = line
|
||||
elif line.strip() == '' and current_times:
|
||||
elif line.strip() == "" and current_times:
|
||||
index += 1
|
||||
times_texts.append((index, current_times.strip(), current_text.strip()))
|
||||
current_times, current_text = None, ""
|
||||
@@ -166,9 +174,10 @@ def levenshtein_distance(s1, s2):
|
||||
substitutions = previous_row[j] + (c1 != c2)
|
||||
current_row.append(min(insertions, deletions, substitutions))
|
||||
previous_row = current_row
|
||||
|
||||
|
||||
return previous_row[-1]
|
||||
|
||||
|
||||
def similarity(a, b):
|
||||
distance = levenshtein_distance(a.lower(), b.lower())
|
||||
max_length = max(len(a), len(b))
|
||||
@@ -194,26 +203,44 @@ def correct(subtitle_file, video_script):
|
||||
subtitle_index += 1
|
||||
else:
|
||||
combined_subtitle = subtitle_line
|
||||
start_time = subtitle_items[subtitle_index][1].split(' --> ')[0]
|
||||
end_time = subtitle_items[subtitle_index][1].split(' --> ')[1]
|
||||
start_time = subtitle_items[subtitle_index][1].split(" --> ")[0]
|
||||
end_time = subtitle_items[subtitle_index][1].split(" --> ")[1]
|
||||
next_subtitle_index = subtitle_index + 1
|
||||
|
||||
while next_subtitle_index < len(subtitle_items):
|
||||
next_subtitle = subtitle_items[next_subtitle_index][2].strip()
|
||||
if similarity(script_line, combined_subtitle + " " + next_subtitle) > similarity(script_line, combined_subtitle):
|
||||
if similarity(
|
||||
script_line, combined_subtitle + " " + next_subtitle
|
||||
) > similarity(script_line, combined_subtitle):
|
||||
combined_subtitle += " " + next_subtitle
|
||||
end_time = subtitle_items[next_subtitle_index][1].split(' --> ')[1]
|
||||
end_time = subtitle_items[next_subtitle_index][1].split(" --> ")[1]
|
||||
next_subtitle_index += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if similarity(script_line, combined_subtitle) > 0.8:
|
||||
logger.warning(f"Merged/Corrected - Script: {script_line}, Subtitle: {combined_subtitle}")
|
||||
new_subtitle_items.append((len(new_subtitle_items) + 1, f"{start_time} --> {end_time}", script_line))
|
||||
logger.warning(
|
||||
f"Merged/Corrected - Script: {script_line}, Subtitle: {combined_subtitle}"
|
||||
)
|
||||
new_subtitle_items.append(
|
||||
(
|
||||
len(new_subtitle_items) + 1,
|
||||
f"{start_time} --> {end_time}",
|
||||
script_line,
|
||||
)
|
||||
)
|
||||
corrected = True
|
||||
else:
|
||||
logger.warning(f"Mismatch - Script: {script_line}, Subtitle: {combined_subtitle}")
|
||||
new_subtitle_items.append((len(new_subtitle_items) + 1, f"{start_time} --> {end_time}", script_line))
|
||||
logger.warning(
|
||||
f"Mismatch - Script: {script_line}, Subtitle: {combined_subtitle}"
|
||||
)
|
||||
new_subtitle_items.append(
|
||||
(
|
||||
len(new_subtitle_items) + 1,
|
||||
f"{start_time} --> {end_time}",
|
||||
script_line,
|
||||
)
|
||||
)
|
||||
corrected = True
|
||||
|
||||
script_index += 1
|
||||
@@ -223,10 +250,22 @@ def correct(subtitle_file, video_script):
|
||||
while script_index < len(script_lines):
|
||||
logger.warning(f"Extra script line: {script_lines[script_index]}")
|
||||
if subtitle_index < len(subtitle_items):
|
||||
new_subtitle_items.append((len(new_subtitle_items) + 1, subtitle_items[subtitle_index][1], script_lines[script_index]))
|
||||
new_subtitle_items.append(
|
||||
(
|
||||
len(new_subtitle_items) + 1,
|
||||
subtitle_items[subtitle_index][1],
|
||||
script_lines[script_index],
|
||||
)
|
||||
)
|
||||
subtitle_index += 1
|
||||
else:
|
||||
new_subtitle_items.append((len(new_subtitle_items) + 1, "00:00:00,000 --> 00:00:00,000", script_lines[script_index]))
|
||||
new_subtitle_items.append(
|
||||
(
|
||||
len(new_subtitle_items) + 1,
|
||||
"00:00:00,000 --> 00:00:00,000",
|
||||
script_lines[script_index],
|
||||
)
|
||||
)
|
||||
script_index += 1
|
||||
corrected = True
|
||||
|
||||
|
||||
@@ -988,7 +988,7 @@ Name: zh-CN-XiaoxiaoMultilingualNeural-V2
|
||||
Gender: Female
|
||||
""".strip()
|
||||
voices = []
|
||||
name = ''
|
||||
name = ""
|
||||
for line in voices_str.split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
@@ -1008,7 +1008,7 @@ Gender: Female
|
||||
voices.append(f"{name}-{gender}")
|
||||
else:
|
||||
voices.append(f"{name}-{gender}")
|
||||
name = ''
|
||||
name = ""
|
||||
voices.sort()
|
||||
return voices
|
||||
|
||||
@@ -1028,7 +1028,9 @@ def is_azure_v2_voice(voice_name: str):
|
||||
return ""
|
||||
|
||||
|
||||
def tts(text: str, voice_name: str, voice_rate: float, voice_file: str) -> [SubMaker, None]:
|
||||
def tts(
|
||||
text: str, voice_name: str, voice_rate: float, voice_file: str
|
||||
) -> [SubMaker, None]:
|
||||
if is_azure_v2_voice(voice_name):
|
||||
return azure_tts_v2(text, voice_name, voice_file)
|
||||
return azure_tts_v1(text, voice_name, voice_rate, voice_file)
|
||||
@@ -1042,9 +1044,11 @@ def convert_rate_to_percent(rate: float) -> str:
|
||||
return f"+{percent}%"
|
||||
else:
|
||||
return f"{percent}%"
|
||||
|
||||
|
||||
def azure_tts_v1(text: str, voice_name: str, voice_rate: float, voice_file: str) -> [SubMaker, None]:
|
||||
|
||||
def azure_tts_v1(
|
||||
text: str, voice_name: str, voice_rate: float, voice_file: str
|
||||
) -> [SubMaker, None]:
|
||||
voice_name = parse_voice_name(voice_name)
|
||||
text = text.strip()
|
||||
rate_str = convert_rate_to_percent(voice_rate)
|
||||
@@ -1060,7 +1064,9 @@ def azure_tts_v1(text: str, voice_name: str, voice_rate: float, voice_file: str)
|
||||
if chunk["type"] == "audio":
|
||||
file.write(chunk["data"])
|
||||
elif chunk["type"] == "WordBoundary":
|
||||
sub_maker.create_sub((chunk["offset"], chunk["duration"]), chunk["text"])
|
||||
sub_maker.create_sub(
|
||||
(chunk["offset"], chunk["duration"]), chunk["text"]
|
||||
)
|
||||
return sub_maker
|
||||
|
||||
sub_maker = asyncio.run(_do())
|
||||
@@ -1085,8 +1091,12 @@ def azure_tts_v2(text: str, voice_name: str, voice_file: str) -> [SubMaker, None
|
||||
def _format_duration_to_offset(duration) -> int:
|
||||
if isinstance(duration, str):
|
||||
time_obj = datetime.strptime(duration, "%H:%M:%S.%f")
|
||||
milliseconds = (time_obj.hour * 3600000) + (time_obj.minute * 60000) + (time_obj.second * 1000) + (
|
||||
time_obj.microsecond // 1000)
|
||||
milliseconds = (
|
||||
(time_obj.hour * 3600000)
|
||||
+ (time_obj.minute * 60000)
|
||||
+ (time_obj.second * 1000)
|
||||
+ (time_obj.microsecond // 1000)
|
||||
)
|
||||
return milliseconds * 10000
|
||||
|
||||
if isinstance(duration, int):
|
||||
@@ -1119,20 +1129,29 @@ def azure_tts_v2(text: str, voice_name: str, voice_file: str) -> [SubMaker, None
|
||||
# Creates an instance of a speech config with specified subscription key and service region.
|
||||
speech_key = config.azure.get("speech_key", "")
|
||||
service_region = config.azure.get("speech_region", "")
|
||||
audio_config = speechsdk.audio.AudioOutputConfig(filename=voice_file, use_default_speaker=True)
|
||||
speech_config = speechsdk.SpeechConfig(subscription=speech_key,
|
||||
region=service_region)
|
||||
audio_config = speechsdk.audio.AudioOutputConfig(
|
||||
filename=voice_file, use_default_speaker=True
|
||||
)
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=speech_key, region=service_region
|
||||
)
|
||||
speech_config.speech_synthesis_voice_name = voice_name
|
||||
# speech_config.set_property(property_id=speechsdk.PropertyId.SpeechServiceResponse_RequestSentenceBoundary,
|
||||
# value='true')
|
||||
speech_config.set_property(property_id=speechsdk.PropertyId.SpeechServiceResponse_RequestWordBoundary,
|
||||
value='true')
|
||||
speech_config.set_property(
|
||||
property_id=speechsdk.PropertyId.SpeechServiceResponse_RequestWordBoundary,
|
||||
value="true",
|
||||
)
|
||||
|
||||
speech_config.set_speech_synthesis_output_format(
|
||||
speechsdk.SpeechSynthesisOutputFormat.Audio48Khz192KBitRateMonoMp3)
|
||||
speech_synthesizer = speechsdk.SpeechSynthesizer(audio_config=audio_config,
|
||||
speech_config=speech_config)
|
||||
speech_synthesizer.synthesis_word_boundary.connect(speech_synthesizer_word_boundary_cb)
|
||||
speechsdk.SpeechSynthesisOutputFormat.Audio48Khz192KBitRateMonoMp3
|
||||
)
|
||||
speech_synthesizer = speechsdk.SpeechSynthesizer(
|
||||
audio_config=audio_config, speech_config=speech_config
|
||||
)
|
||||
speech_synthesizer.synthesis_word_boundary.connect(
|
||||
speech_synthesizer_word_boundary_cb
|
||||
)
|
||||
|
||||
result = speech_synthesizer.speak_text_async(text).get()
|
||||
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
|
||||
@@ -1140,9 +1159,13 @@ def azure_tts_v2(text: str, voice_name: str, voice_file: str) -> [SubMaker, None
|
||||
return sub_maker
|
||||
elif result.reason == speechsdk.ResultReason.Canceled:
|
||||
cancellation_details = result.cancellation_details
|
||||
logger.error(f"azure v2 speech synthesis canceled: {cancellation_details.reason}")
|
||||
logger.error(
|
||||
f"azure v2 speech synthesis canceled: {cancellation_details.reason}"
|
||||
)
|
||||
if cancellation_details.reason == speechsdk.CancellationReason.Error:
|
||||
logger.error(f"azure v2 speech synthesis error: {cancellation_details.error_details}")
|
||||
logger.error(
|
||||
f"azure v2 speech synthesis error: {cancellation_details.error_details}"
|
||||
)
|
||||
logger.info(f"completed, output file: {voice_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"failed, error: {str(e)}")
|
||||
@@ -1179,11 +1202,7 @@ def create_subtitle(sub_maker: submaker.SubMaker, text: str, subtitle_file: str)
|
||||
"""
|
||||
start_t = mktimestamp(start_time).replace(".", ",")
|
||||
end_t = mktimestamp(end_time).replace(".", ",")
|
||||
return (
|
||||
f"{idx}\n"
|
||||
f"{start_t} --> {end_t}\n"
|
||||
f"{sub_text}\n"
|
||||
)
|
||||
return f"{idx}\n" f"{start_t} --> {end_t}\n" f"{sub_text}\n"
|
||||
|
||||
start_time = -1.0
|
||||
sub_items = []
|
||||
@@ -1240,12 +1259,16 @@ def create_subtitle(sub_maker: submaker.SubMaker, text: str, subtitle_file: str)
|
||||
try:
|
||||
sbs = subtitles.file_to_subtitles(subtitle_file, encoding="utf-8")
|
||||
duration = max([tb for ((ta, tb), txt) in sbs])
|
||||
logger.info(f"completed, subtitle file created: {subtitle_file}, duration: {duration}")
|
||||
logger.info(
|
||||
f"completed, subtitle file created: {subtitle_file}, duration: {duration}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"failed, error: {str(e)}")
|
||||
os.remove(subtitle_file)
|
||||
else:
|
||||
logger.warning(f"failed, sub_items len: {len(sub_items)}, script_lines len: {len(script_lines)}")
|
||||
logger.warning(
|
||||
f"failed, sub_items len: {len(sub_items)}, script_lines len: {len(script_lines)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"failed, error: {str(e)}")
|
||||
@@ -1269,7 +1292,6 @@ if __name__ == "__main__":
|
||||
voices = get_all_azure_voices()
|
||||
print(len(voices))
|
||||
|
||||
|
||||
async def _do():
|
||||
temp_dir = utils.storage_dir("temp")
|
||||
|
||||
@@ -1318,12 +1340,13 @@ if __name__ == "__main__":
|
||||
for voice_name in voice_names:
|
||||
voice_file = f"{temp_dir}/tts-{voice_name}.mp3"
|
||||
subtitle_file = f"{temp_dir}/tts.mp3.srt"
|
||||
sub_maker = azure_tts_v2(text=text, voice_name=voice_name, voice_file=voice_file)
|
||||
sub_maker = azure_tts_v2(
|
||||
text=text, voice_name=voice_name, voice_file=voice_file
|
||||
)
|
||||
create_subtitle(sub_maker=sub_maker, text=text, subtitle_file=subtitle_file)
|
||||
audio_duration = get_audio_duration(sub_maker)
|
||||
print(f"voice: {voice_name}, audio duration: {audio_duration}s")
|
||||
|
||||
|
||||
loop = asyncio.get_event_loop_policy().get_event_loop()
|
||||
try:
|
||||
loop.run_until_complete(_do())
|
||||
|
||||
Reference in New Issue
Block a user