Format project code

This commit is contained in:
yyhhyyyyyy
2024-07-24 14:59:06 +08:00
parent bbd4e94941
commit 905841965a
18 changed files with 410 additions and 214 deletions

View File

@@ -21,6 +21,7 @@ def _generate_response(prompt: str) -> str:
if not model_name:
model_name = "gpt-3.5-turbo-16k-0613"
import g4f
content = g4f.ChatCompletion.create(
model=model_name,
messages=[{"role": "user", "content": prompt}],
@@ -78,44 +79,56 @@ def _generate_response(prompt: str) -> str:
base_url = config.app.get("ernie_base_url")
model_name = "***"
if not secret_key:
raise ValueError(f"{llm_provider}: secret_key is not set, please set it in the config.toml file.")
raise ValueError(
f"{llm_provider}: secret_key is not set, please set it in the config.toml file."
)
else:
raise ValueError("llm_provider is not set, please set it in the config.toml file.")
raise ValueError(
"llm_provider is not set, please set it in the config.toml file."
)
if not api_key:
raise ValueError(f"{llm_provider}: api_key is not set, please set it in the config.toml file.")
raise ValueError(
f"{llm_provider}: api_key is not set, please set it in the config.toml file."
)
if not model_name:
raise ValueError(f"{llm_provider}: model_name is not set, please set it in the config.toml file.")
raise ValueError(
f"{llm_provider}: model_name is not set, please set it in the config.toml file."
)
if not base_url:
raise ValueError(f"{llm_provider}: base_url is not set, please set it in the config.toml file.")
raise ValueError(
f"{llm_provider}: base_url is not set, please set it in the config.toml file."
)
if llm_provider == "qwen":
import dashscope
from dashscope.api_entities.dashscope_response import GenerationResponse
dashscope.api_key = api_key
response = dashscope.Generation.call(
model=model_name,
messages=[{"role": "user", "content": prompt}]
model=model_name, messages=[{"role": "user", "content": prompt}]
)
if response:
if isinstance(response, GenerationResponse):
status_code = response.status_code
if status_code != 200:
raise Exception(
f"[{llm_provider}] returned an error response: \"{response}\"")
f'[{llm_provider}] returned an error response: "{response}"'
)
content = response["output"]["text"]
return content.replace("\n", "")
else:
raise Exception(
f"[{llm_provider}] returned an invalid response: \"{response}\"")
f'[{llm_provider}] returned an invalid response: "{response}"'
)
else:
raise Exception(
f"[{llm_provider}] returned an empty response")
raise Exception(f"[{llm_provider}] returned an empty response")
if llm_provider == "gemini":
import google.generativeai as genai
genai.configure(api_key=api_key, transport='rest')
genai.configure(api_key=api_key, transport="rest")
generation_config = {
"temperature": 0.5,
@@ -127,25 +140,27 @@ def _generate_response(prompt: str) -> str:
safety_settings = [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_ONLY_HIGH"
"threshold": "BLOCK_ONLY_HIGH",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_ONLY_HIGH"
"threshold": "BLOCK_ONLY_HIGH",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_ONLY_HIGH"
"threshold": "BLOCK_ONLY_HIGH",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_ONLY_HIGH"
"threshold": "BLOCK_ONLY_HIGH",
},
]
model = genai.GenerativeModel(model_name=model_name,
generation_config=generation_config,
safety_settings=safety_settings)
model = genai.GenerativeModel(
model_name=model_name,
generation_config=generation_config,
safety_settings=safety_settings,
)
try:
response = model.generate_content(prompt)
@@ -158,15 +173,16 @@ def _generate_response(prompt: str) -> str:
if llm_provider == "cloudflare":
import requests
response = requests.post(
f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}",
headers={"Authorization": f"Bearer {api_key}"},
json={
"messages": [
{"role": "system", "content": "You are a friendly assistant"},
{"role": "user", "content": prompt}
{"role": "user", "content": prompt},
]
}
},
)
result = response.json()
logger.info(result)
@@ -174,30 +190,35 @@ def _generate_response(prompt: str) -> str:
if llm_provider == "ernie":
import requests
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
access_token = requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params).json().get(
"access_token")
params = {
"grant_type": "client_credentials",
"client_id": api_key,
"client_secret": secret_key,
}
access_token = (
requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params)
.json()
.get("access_token")
)
url = f"{base_url}?access_token={access_token}"
payload = json.dumps({
"messages": [
{
"role": "user",
"content": prompt
}
],
"temperature": 0.5,
"top_p": 0.8,
"penalty_score": 1,
"disable_search": False,
"enable_citation": False,
"response_format": "text"
})
headers = {
'Content-Type': 'application/json'
}
payload = json.dumps(
{
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.5,
"top_p": 0.8,
"penalty_score": 1,
"disable_search": False,
"enable_citation": False,
"response_format": "text",
}
)
headers = {"Content-Type": "application/json"}
response = requests.request("POST", url, headers=headers, data=payload).json()
response = requests.request(
"POST", url, headers=headers, data=payload
).json()
return response.get("result")
if llm_provider == "azure":
@@ -213,24 +234,27 @@ def _generate_response(prompt: str) -> str:
)
response = client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": prompt}]
model=model_name, messages=[{"role": "user", "content": prompt}]
)
if response:
if isinstance(response, ChatCompletion):
content = response.choices[0].message.content
else:
raise Exception(
f"[{llm_provider}] returned an invalid response: \"{response}\", please check your network "
f"connection and try again.")
f'[{llm_provider}] returned an invalid response: "{response}", please check your network '
f"connection and try again."
)
else:
raise Exception(
f"[{llm_provider}] returned an empty response, please check your network connection and try again.")
f"[{llm_provider}] returned an empty response, please check your network connection and try again."
)
return content.replace("\n", "")
def generate_script(video_subject: str, language: str = "", paragraph_number: int = 1) -> str:
def generate_script(
video_subject: str, language: str = "", paragraph_number: int = 1
) -> str:
prompt = f"""
# Role: Video Script Generator
@@ -335,14 +359,16 @@ Please note that you must use English for generating video search terms; Chinese
try:
response = _generate_response(prompt)
search_terms = json.loads(response)
if not isinstance(search_terms, list) or not all(isinstance(term, str) for term in search_terms):
if not isinstance(search_terms, list) or not all(
isinstance(term, str) for term in search_terms
):
logger.error("response is not a list of strings.")
continue
except Exception as e:
logger.warning(f"failed to generate video terms: {str(e)}")
if response:
match = re.search(r'\[.*]', response)
match = re.search(r"\[.*]", response)
if match:
try:
search_terms = json.loads(match.group())
@@ -361,9 +387,13 @@ Please note that you must use English for generating video search terms; Chinese
if __name__ == "__main__":
video_subject = "生命的意义是什么"
script = generate_script(video_subject=video_subject, language="zh-CN", paragraph_number=1)
script = generate_script(
video_subject=video_subject, language="zh-CN", paragraph_number=1
)
print("######################")
print(script)
search_terms = generate_terms(video_subject=video_subject, video_script=script, amount=5)
search_terms = generate_terms(
video_subject=video_subject, video_script=script, amount=5
)
print("######################")
print(search_terms)

View File

@@ -19,7 +19,8 @@ def get_api_key(cfg_key: str):
if not api_keys:
raise ValueError(
f"\n\n##### {cfg_key} is not set #####\n\nPlease set it in the config.toml file: {config.config_file}\n\n"
f"{utils.to_json(config.app)}")
f"{utils.to_json(config.app)}"
)
# if only one key is provided, return it
if isinstance(api_keys, str):
@@ -30,28 +31,29 @@ def get_api_key(cfg_key: str):
return api_keys[requested_count % len(api_keys)]
def search_videos_pexels(search_term: str,
minimum_duration: int,
video_aspect: VideoAspect = VideoAspect.portrait,
) -> List[MaterialInfo]:
def search_videos_pexels(
search_term: str,
minimum_duration: int,
video_aspect: VideoAspect = VideoAspect.portrait,
) -> List[MaterialInfo]:
aspect = VideoAspect(video_aspect)
video_orientation = aspect.name
video_width, video_height = aspect.to_resolution()
api_key = get_api_key("pexels_api_keys")
headers = {
"Authorization": api_key
}
headers = {"Authorization": api_key}
# Build URL
params = {
"query": search_term,
"per_page": 20,
"orientation": video_orientation
}
params = {"query": search_term, "per_page": 20, "orientation": video_orientation}
query_url = f"https://api.pexels.com/videos/search?{urlencode(params)}"
logger.info(f"searching videos: {query_url}, with proxies: {config.proxy}")
try:
r = requests.get(query_url, headers=headers, proxies=config.proxy, verify=False, timeout=(30, 60))
r = requests.get(
query_url,
headers=headers,
proxies=config.proxy,
verify=False,
timeout=(30, 60),
)
response = r.json()
video_items = []
if "videos" not in response:
@@ -83,10 +85,11 @@ def search_videos_pexels(search_term: str,
return []
def search_videos_pixabay(search_term: str,
minimum_duration: int,
video_aspect: VideoAspect = VideoAspect.portrait,
) -> List[MaterialInfo]:
def search_videos_pixabay(
search_term: str,
minimum_duration: int,
video_aspect: VideoAspect = VideoAspect.portrait,
) -> List[MaterialInfo]:
aspect = VideoAspect(video_aspect)
video_width, video_height = aspect.to_resolution()
@@ -97,13 +100,15 @@ def search_videos_pixabay(search_term: str,
"q": search_term,
"video_type": "all", # Accepted values: "all", "film", "animation"
"per_page": 50,
"key": api_key
"key": api_key,
}
query_url = f"https://pixabay.com/api/videos/?{urlencode(params)}"
logger.info(f"searching videos: {query_url}, with proxies: {config.proxy}")
try:
r = requests.get(query_url, proxies=config.proxy, verify=False, timeout=(30, 60))
r = requests.get(
query_url, proxies=config.proxy, verify=False, timeout=(30, 60)
)
response = r.json()
video_items = []
if "hits" not in response:
@@ -155,7 +160,11 @@ def save_video(video_url: str, save_dir: str = "") -> str:
# if video does not exist, download it
with open(video_path, "wb") as f:
f.write(requests.get(video_url, proxies=config.proxy, verify=False, timeout=(60, 240)).content)
f.write(
requests.get(
video_url, proxies=config.proxy, verify=False, timeout=(60, 240)
).content
)
if os.path.exists(video_path) and os.path.getsize(video_path) > 0:
try:
@@ -174,14 +183,15 @@ def save_video(video_url: str, save_dir: str = "") -> str:
return ""
def download_videos(task_id: str,
search_terms: List[str],
source: str = "pexels",
video_aspect: VideoAspect = VideoAspect.portrait,
video_contact_mode: VideoConcatMode = VideoConcatMode.random,
audio_duration: float = 0.0,
max_clip_duration: int = 5,
) -> List[str]:
def download_videos(
task_id: str,
search_terms: List[str],
source: str = "pexels",
video_aspect: VideoAspect = VideoAspect.portrait,
video_contact_mode: VideoConcatMode = VideoConcatMode.random,
audio_duration: float = 0.0,
max_clip_duration: int = 5,
) -> List[str]:
valid_video_items = []
valid_video_urls = []
found_duration = 0.0
@@ -190,9 +200,11 @@ def download_videos(task_id: str,
search_videos = search_videos_pixabay
for search_term in search_terms:
video_items = search_videos(search_term=search_term,
minimum_duration=max_clip_duration,
video_aspect=video_aspect)
video_items = search_videos(
search_term=search_term,
minimum_duration=max_clip_duration,
video_aspect=video_aspect,
)
logger.info(f"found {len(video_items)} videos for '{search_term}'")
for item in video_items:
@@ -202,7 +214,8 @@ def download_videos(task_id: str,
found_duration += item.duration
logger.info(
f"found total videos: {len(valid_video_items)}, required duration: {audio_duration} seconds, found duration: {found_duration} seconds")
f"found total videos: {len(valid_video_items)}, required duration: {audio_duration} seconds, found duration: {found_duration} seconds"
)
video_paths = []
material_directory = config.app.get("material_directory", "").strip()
@@ -218,14 +231,18 @@ def download_videos(task_id: str,
for item in valid_video_items:
try:
logger.info(f"downloading video: {item.url}")
saved_video_path = save_video(video_url=item.url, save_dir=material_directory)
saved_video_path = save_video(
video_url=item.url, save_dir=material_directory
)
if saved_video_path:
logger.info(f"video saved: {saved_video_path}")
video_paths.append(saved_video_path)
seconds = min(max_clip_duration, item.duration)
total_duration += seconds
if total_duration > audio_duration:
logger.info(f"total duration of downloaded videos: {total_duration} seconds, skip downloading more")
logger.info(
f"total duration of downloaded videos: {total_duration} seconds, skip downloading more"
)
break
except Exception as e:
logger.error(f"failed to download video: {utils.to_json(item)} => {str(e)}")
@@ -234,4 +251,6 @@ def download_videos(task_id: str,
if __name__ == "__main__":
download_videos("test123", ["Money Exchange Medium"], audio_duration=100, source="pixabay")
download_videos(
"test123", ["Money Exchange Medium"], audio_duration=100, source="pixabay"
)

View File

@@ -6,7 +6,6 @@ from app.models import const
# Base class for state management
class BaseState(ABC):
@abstractmethod
def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
pass
@@ -18,11 +17,16 @@ class BaseState(ABC):
# Memory state management
class MemoryState(BaseState):
def __init__(self):
self._tasks = {}
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
def update_task(
self,
task_id: str,
state: int = const.TASK_STATE_PROCESSING,
progress: int = 0,
**kwargs,
):
progress = int(progress)
if progress > 100:
progress = 100
@@ -43,12 +47,18 @@ class MemoryState(BaseState):
# Redis state management
class RedisState(BaseState):
def __init__(self, host='localhost', port=6379, db=0, password=None):
def __init__(self, host="localhost", port=6379, db=0, password=None):
import redis
self._redis = redis.StrictRedis(host=host, port=port, db=db, password=password)
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
def update_task(
self,
task_id: str,
state: int = const.TASK_STATE_PROCESSING,
progress: int = 0,
**kwargs,
):
progress = int(progress)
if progress > 100:
progress = 100
@@ -67,7 +77,10 @@ class RedisState(BaseState):
if not task_data:
return None
task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()}
task = {
key.decode("utf-8"): self._convert_to_original_type(value)
for key, value in task_data.items()
}
return task
def delete_task(self, task_id: str):
@@ -79,7 +92,7 @@ class RedisState(BaseState):
Convert the value from byte string to its original data type.
You can extend this method to handle other data types as needed.
"""
value_str = value.decode('utf-8')
value_str = value.decode("utf-8")
try:
# try to convert byte string array to list
@@ -100,4 +113,10 @@ _redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
_redis_password = config.app.get("redis_password", None)
state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db, password=_redis_password) if _enable_redis else MemoryState()
state = (
RedisState(
host=_redis_host, port=_redis_port, db=_redis_db, password=_redis_password
)
if _enable_redis
else MemoryState()
)

View File

@@ -23,18 +23,22 @@ def create(audio_file, subtitle_file: str = ""):
if not os.path.isdir(model_path) or not os.path.isfile(model_bin_file):
model_path = model_size
logger.info(f"loading model: {model_path}, device: {device}, compute_type: {compute_type}")
logger.info(
f"loading model: {model_path}, device: {device}, compute_type: {compute_type}"
)
try:
model = WhisperModel(model_size_or_path=model_path,
device=device,
compute_type=compute_type)
model = WhisperModel(
model_size_or_path=model_path, device=device, compute_type=compute_type
)
except Exception as e:
logger.error(f"failed to load model: {e} \n\n"
f"********************************************\n"
f"this may be caused by network issue. \n"
f"please download the model manually and put it in the 'models' folder. \n"
f"see [README.md FAQ](https://github.com/harry0703/MoneyPrinterTurbo) for more details.\n"
f"********************************************\n\n")
logger.error(
f"failed to load model: {e} \n\n"
f"********************************************\n"
f"this may be caused by network issue. \n"
f"please download the model manually and put it in the 'models' folder. \n"
f"see [README.md FAQ](https://github.com/harry0703/MoneyPrinterTurbo) for more details.\n"
f"********************************************\n\n"
)
return None
logger.info(f"start, output file: {subtitle_file}")
@@ -49,7 +53,9 @@ def create(audio_file, subtitle_file: str = ""):
vad_parameters=dict(min_silence_duration_ms=500),
)
logger.info(f"detected language: '{info.language}', probability: {info.language_probability:.2f}")
logger.info(
f"detected language: '{info.language}', probability: {info.language_probability:.2f}"
)
start = timer()
subtitles = []
@@ -62,11 +68,9 @@ def create(audio_file, subtitle_file: str = ""):
msg = "[%.2fs -> %.2fs] %s" % (seg_start, seg_end, seg_text)
logger.debug(msg)
subtitles.append({
"msg": seg_text,
"start_time": seg_start,
"end_time": seg_end
})
subtitles.append(
{"msg": seg_text, "start_time": seg_start, "end_time": seg_end}
)
for segment in segments:
words_idx = 0
@@ -119,7 +123,11 @@ def create(audio_file, subtitle_file: str = ""):
for subtitle in subtitles:
text = subtitle.get("msg")
if text:
lines.append(utils.text_to_srt(idx, text, subtitle.get("start_time"), subtitle.get("end_time")))
lines.append(
utils.text_to_srt(
idx, text, subtitle.get("start_time"), subtitle.get("end_time")
)
)
idx += 1
sub = "\n".join(lines) + "\n"
@@ -136,12 +144,12 @@ def file_to_subtitles(filename):
current_times = None
current_text = ""
index = 0
with open(filename, 'r', encoding="utf-8") as f:
with open(filename, "r", encoding="utf-8") as f:
for line in f:
times = re.findall("([0-9]*:[0-9]*:[0-9]*,[0-9]*)", line)
if times:
current_times = line
elif line.strip() == '' and current_times:
elif line.strip() == "" and current_times:
index += 1
times_texts.append((index, current_times.strip(), current_text.strip()))
current_times, current_text = None, ""
@@ -166,9 +174,10 @@ def levenshtein_distance(s1, s2):
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row
return previous_row[-1]
def similarity(a, b):
distance = levenshtein_distance(a.lower(), b.lower())
max_length = max(len(a), len(b))
@@ -194,26 +203,44 @@ def correct(subtitle_file, video_script):
subtitle_index += 1
else:
combined_subtitle = subtitle_line
start_time = subtitle_items[subtitle_index][1].split(' --> ')[0]
end_time = subtitle_items[subtitle_index][1].split(' --> ')[1]
start_time = subtitle_items[subtitle_index][1].split(" --> ")[0]
end_time = subtitle_items[subtitle_index][1].split(" --> ")[1]
next_subtitle_index = subtitle_index + 1
while next_subtitle_index < len(subtitle_items):
next_subtitle = subtitle_items[next_subtitle_index][2].strip()
if similarity(script_line, combined_subtitle + " " + next_subtitle) > similarity(script_line, combined_subtitle):
if similarity(
script_line, combined_subtitle + " " + next_subtitle
) > similarity(script_line, combined_subtitle):
combined_subtitle += " " + next_subtitle
end_time = subtitle_items[next_subtitle_index][1].split(' --> ')[1]
end_time = subtitle_items[next_subtitle_index][1].split(" --> ")[1]
next_subtitle_index += 1
else:
break
if similarity(script_line, combined_subtitle) > 0.8:
logger.warning(f"Merged/Corrected - Script: {script_line}, Subtitle: {combined_subtitle}")
new_subtitle_items.append((len(new_subtitle_items) + 1, f"{start_time} --> {end_time}", script_line))
logger.warning(
f"Merged/Corrected - Script: {script_line}, Subtitle: {combined_subtitle}"
)
new_subtitle_items.append(
(
len(new_subtitle_items) + 1,
f"{start_time} --> {end_time}",
script_line,
)
)
corrected = True
else:
logger.warning(f"Mismatch - Script: {script_line}, Subtitle: {combined_subtitle}")
new_subtitle_items.append((len(new_subtitle_items) + 1, f"{start_time} --> {end_time}", script_line))
logger.warning(
f"Mismatch - Script: {script_line}, Subtitle: {combined_subtitle}"
)
new_subtitle_items.append(
(
len(new_subtitle_items) + 1,
f"{start_time} --> {end_time}",
script_line,
)
)
corrected = True
script_index += 1
@@ -223,10 +250,22 @@ def correct(subtitle_file, video_script):
while script_index < len(script_lines):
logger.warning(f"Extra script line: {script_lines[script_index]}")
if subtitle_index < len(subtitle_items):
new_subtitle_items.append((len(new_subtitle_items) + 1, subtitle_items[subtitle_index][1], script_lines[script_index]))
new_subtitle_items.append(
(
len(new_subtitle_items) + 1,
subtitle_items[subtitle_index][1],
script_lines[script_index],
)
)
subtitle_index += 1
else:
new_subtitle_items.append((len(new_subtitle_items) + 1, "00:00:00,000 --> 00:00:00,000", script_lines[script_index]))
new_subtitle_items.append(
(
len(new_subtitle_items) + 1,
"00:00:00,000 --> 00:00:00,000",
script_lines[script_index],
)
)
script_index += 1
corrected = True

View File

@@ -988,7 +988,7 @@ Name: zh-CN-XiaoxiaoMultilingualNeural-V2
Gender: Female
""".strip()
voices = []
name = ''
name = ""
for line in voices_str.split("\n"):
line = line.strip()
if not line:
@@ -1008,7 +1008,7 @@ Gender: Female
voices.append(f"{name}-{gender}")
else:
voices.append(f"{name}-{gender}")
name = ''
name = ""
voices.sort()
return voices
@@ -1028,7 +1028,9 @@ def is_azure_v2_voice(voice_name: str):
return ""
def tts(text: str, voice_name: str, voice_rate: float, voice_file: str) -> [SubMaker, None]:
def tts(
text: str, voice_name: str, voice_rate: float, voice_file: str
) -> [SubMaker, None]:
if is_azure_v2_voice(voice_name):
return azure_tts_v2(text, voice_name, voice_file)
return azure_tts_v1(text, voice_name, voice_rate, voice_file)
@@ -1042,9 +1044,11 @@ def convert_rate_to_percent(rate: float) -> str:
return f"+{percent}%"
else:
return f"{percent}%"
def azure_tts_v1(text: str, voice_name: str, voice_rate: float, voice_file: str) -> [SubMaker, None]:
def azure_tts_v1(
text: str, voice_name: str, voice_rate: float, voice_file: str
) -> [SubMaker, None]:
voice_name = parse_voice_name(voice_name)
text = text.strip()
rate_str = convert_rate_to_percent(voice_rate)
@@ -1060,7 +1064,9 @@ def azure_tts_v1(text: str, voice_name: str, voice_rate: float, voice_file: str)
if chunk["type"] == "audio":
file.write(chunk["data"])
elif chunk["type"] == "WordBoundary":
sub_maker.create_sub((chunk["offset"], chunk["duration"]), chunk["text"])
sub_maker.create_sub(
(chunk["offset"], chunk["duration"]), chunk["text"]
)
return sub_maker
sub_maker = asyncio.run(_do())
@@ -1085,8 +1091,12 @@ def azure_tts_v2(text: str, voice_name: str, voice_file: str) -> [SubMaker, None
def _format_duration_to_offset(duration) -> int:
if isinstance(duration, str):
time_obj = datetime.strptime(duration, "%H:%M:%S.%f")
milliseconds = (time_obj.hour * 3600000) + (time_obj.minute * 60000) + (time_obj.second * 1000) + (
time_obj.microsecond // 1000)
milliseconds = (
(time_obj.hour * 3600000)
+ (time_obj.minute * 60000)
+ (time_obj.second * 1000)
+ (time_obj.microsecond // 1000)
)
return milliseconds * 10000
if isinstance(duration, int):
@@ -1119,20 +1129,29 @@ def azure_tts_v2(text: str, voice_name: str, voice_file: str) -> [SubMaker, None
# Creates an instance of a speech config with specified subscription key and service region.
speech_key = config.azure.get("speech_key", "")
service_region = config.azure.get("speech_region", "")
audio_config = speechsdk.audio.AudioOutputConfig(filename=voice_file, use_default_speaker=True)
speech_config = speechsdk.SpeechConfig(subscription=speech_key,
region=service_region)
audio_config = speechsdk.audio.AudioOutputConfig(
filename=voice_file, use_default_speaker=True
)
speech_config = speechsdk.SpeechConfig(
subscription=speech_key, region=service_region
)
speech_config.speech_synthesis_voice_name = voice_name
# speech_config.set_property(property_id=speechsdk.PropertyId.SpeechServiceResponse_RequestSentenceBoundary,
# value='true')
speech_config.set_property(property_id=speechsdk.PropertyId.SpeechServiceResponse_RequestWordBoundary,
value='true')
speech_config.set_property(
property_id=speechsdk.PropertyId.SpeechServiceResponse_RequestWordBoundary,
value="true",
)
speech_config.set_speech_synthesis_output_format(
speechsdk.SpeechSynthesisOutputFormat.Audio48Khz192KBitRateMonoMp3)
speech_synthesizer = speechsdk.SpeechSynthesizer(audio_config=audio_config,
speech_config=speech_config)
speech_synthesizer.synthesis_word_boundary.connect(speech_synthesizer_word_boundary_cb)
speechsdk.SpeechSynthesisOutputFormat.Audio48Khz192KBitRateMonoMp3
)
speech_synthesizer = speechsdk.SpeechSynthesizer(
audio_config=audio_config, speech_config=speech_config
)
speech_synthesizer.synthesis_word_boundary.connect(
speech_synthesizer_word_boundary_cb
)
result = speech_synthesizer.speak_text_async(text).get()
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
@@ -1140,9 +1159,13 @@ def azure_tts_v2(text: str, voice_name: str, voice_file: str) -> [SubMaker, None
return sub_maker
elif result.reason == speechsdk.ResultReason.Canceled:
cancellation_details = result.cancellation_details
logger.error(f"azure v2 speech synthesis canceled: {cancellation_details.reason}")
logger.error(
f"azure v2 speech synthesis canceled: {cancellation_details.reason}"
)
if cancellation_details.reason == speechsdk.CancellationReason.Error:
logger.error(f"azure v2 speech synthesis error: {cancellation_details.error_details}")
logger.error(
f"azure v2 speech synthesis error: {cancellation_details.error_details}"
)
logger.info(f"completed, output file: {voice_file}")
except Exception as e:
logger.error(f"failed, error: {str(e)}")
@@ -1179,11 +1202,7 @@ def create_subtitle(sub_maker: submaker.SubMaker, text: str, subtitle_file: str)
"""
start_t = mktimestamp(start_time).replace(".", ",")
end_t = mktimestamp(end_time).replace(".", ",")
return (
f"{idx}\n"
f"{start_t} --> {end_t}\n"
f"{sub_text}\n"
)
return f"{idx}\n" f"{start_t} --> {end_t}\n" f"{sub_text}\n"
start_time = -1.0
sub_items = []
@@ -1240,12 +1259,16 @@ def create_subtitle(sub_maker: submaker.SubMaker, text: str, subtitle_file: str)
try:
sbs = subtitles.file_to_subtitles(subtitle_file, encoding="utf-8")
duration = max([tb for ((ta, tb), txt) in sbs])
logger.info(f"completed, subtitle file created: {subtitle_file}, duration: {duration}")
logger.info(
f"completed, subtitle file created: {subtitle_file}, duration: {duration}"
)
except Exception as e:
logger.error(f"failed, error: {str(e)}")
os.remove(subtitle_file)
else:
logger.warning(f"failed, sub_items len: {len(sub_items)}, script_lines len: {len(script_lines)}")
logger.warning(
f"failed, sub_items len: {len(sub_items)}, script_lines len: {len(script_lines)}"
)
except Exception as e:
logger.error(f"failed, error: {str(e)}")
@@ -1269,7 +1292,6 @@ if __name__ == "__main__":
voices = get_all_azure_voices()
print(len(voices))
async def _do():
temp_dir = utils.storage_dir("temp")
@@ -1318,12 +1340,13 @@ if __name__ == "__main__":
for voice_name in voice_names:
voice_file = f"{temp_dir}/tts-{voice_name}.mp3"
subtitle_file = f"{temp_dir}/tts.mp3.srt"
sub_maker = azure_tts_v2(text=text, voice_name=voice_name, voice_file=voice_file)
sub_maker = azure_tts_v2(
text=text, voice_name=voice_name, voice_file=voice_file
)
create_subtitle(sub_maker=sub_maker, text=text, subtitle_file=subtitle_file)
audio_duration = get_audio_duration(sub_maker)
print(f"voice: {voice_name}, audio duration: {audio_duration}s")
loop = asyncio.get_event_loop_policy().get_event_loop()
try:
loop.run_until_complete(_do())