🐛 fix: handle LLM errors gracefully and propagate them to the task runner and UI

This commit is contained in:
yyhhyyyyyy
2024-12-12 14:29:14 +08:00
parent 85d446e2d0
commit 2d8cd23fe7
3 changed files with 270 additions and 249 deletions

View File

@@ -3,6 +3,7 @@ import logging
import re import re
from typing import List from typing import List
import g4f
from loguru import logger from loguru import logger
from openai import AzureOpenAI, OpenAI from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletion from openai.types.chat import ChatCompletion
@@ -13,6 +14,7 @@ _max_retries = 5
def _generate_response(prompt: str) -> str: def _generate_response(prompt: str) -> str:
try:
content = "" content = ""
llm_provider = config.app.get("llm_provider", "openai") llm_provider = config.app.get("llm_provider", "openai")
logger.info(f"llm provider: {llm_provider}") logger.info(f"llm provider: {llm_provider}")
@@ -20,8 +22,6 @@ def _generate_response(prompt: str) -> str:
model_name = config.app.get("g4f_model_name", "") model_name = config.app.get("g4f_model_name", "")
if not model_name: if not model_name:
model_name = "gpt-3.5-turbo-16k-0613" model_name = "gpt-3.5-turbo-16k-0613"
import g4f
content = g4f.ChatCompletion.create( content = g4f.ChatCompletion.create(
model=model_name, model=model_name,
messages=[{"role": "user", "content": prompt}], messages=[{"role": "user", "content": prompt}],
@@ -250,6 +250,8 @@ def _generate_response(prompt: str) -> str:
) )
return content.replace("\n", "") return content.replace("\n", "")
except Exception as e:
return f"Error: {str(e)}"
def generate_script( def generate_script(
@@ -319,7 +321,9 @@ Generate a script for a video, depending on the subject of the video.
if i < _max_retries: if i < _max_retries:
logger.warning(f"failed to generate video script, trying again... {i + 1}") logger.warning(f"failed to generate video script, trying again... {i + 1}")
if "Error: " in final_script:
logger.error(f"failed to generate video script: {final_script}")
else:
logger.success(f"completed: \n{final_script}") logger.success(f"completed: \n{final_script}")
return final_script.strip() return final_script.strip()
@@ -358,6 +362,9 @@ Please note that you must use English for generating video search terms; Chinese
for i in range(_max_retries): for i in range(_max_retries):
try: try:
response = _generate_response(prompt) response = _generate_response(prompt)
if "Error: " in response:
logger.error(f"failed to generate video script: {response}")
return response
search_terms = json.loads(response) search_terms = json.loads(response)
if not isinstance(search_terms, list) or not all( if not isinstance(search_terms, list) or not all(
isinstance(term, str) for term in search_terms isinstance(term, str) for term in search_terms

View File

@@ -214,7 +214,7 @@ def start(task_id, params: VideoParams, stop_at: str = "video"):
# 1. Generate script # 1. Generate script
video_script = generate_script(task_id, params) video_script = generate_script(task_id, params)
if not video_script: if not video_script or "Error: " in video_script:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
return return

View File

@@ -449,8 +449,12 @@ with left_panel:
selected_index = st.selectbox( selected_index = st.selectbox(
tr("Script Language"), tr("Script Language"),
index=0, index=0,
options=range(len(video_languages)), # 使用索引作为内部选项值 options=range(
format_func=lambda x: video_languages[x][0], # 显示给用户的是标签 len(video_languages)
), # Use the index as the internal option value
format_func=lambda x: video_languages[x][
0
], # The label is displayed to the user
) )
params.video_language = video_languages[selected_index][1] params.video_language = video_languages[selected_index][1]
@@ -462,9 +466,13 @@ with left_panel:
video_subject=params.video_subject, language=params.video_language video_subject=params.video_subject, language=params.video_language
) )
terms = llm.generate_terms(params.video_subject, script) terms = llm.generate_terms(params.video_subject, script)
if "Error: " in script:
st.error(tr(script))
elif "Error: " in terms:
st.error(tr(terms))
else:
st.session_state["video_script"] = script st.session_state["video_script"] = script
st.session_state["video_terms"] = ", ".join(terms) st.session_state["video_terms"] = ", ".join(terms)
params.video_script = st.text_area( params.video_script = st.text_area(
tr("Video Script"), value=st.session_state["video_script"], height=280 tr("Video Script"), value=st.session_state["video_script"], height=280
) )
@@ -475,6 +483,9 @@ with left_panel:
with st.spinner(tr("Generating Video Keywords")): with st.spinner(tr("Generating Video Keywords")):
terms = llm.generate_terms(params.video_subject, params.video_script) terms = llm.generate_terms(params.video_subject, params.video_script)
if "Error: " in terms:
st.error(tr(terms))
else:
st.session_state["video_terms"] = ", ".join(terms) st.session_state["video_terms"] = ", ".join(terms)
params.video_terms = st.text_area( params.video_terms = st.text_area(
@@ -522,8 +533,12 @@ with middle_panel:
selected_index = st.selectbox( selected_index = st.selectbox(
tr("Video Concat Mode"), tr("Video Concat Mode"),
index=1, index=1,
options=range(len(video_concat_modes)), # 使用索引作为内部选项值 options=range(
format_func=lambda x: video_concat_modes[x][0], # 显示给用户的是标签 len(video_concat_modes)
), # Use the index as the internal option value
format_func=lambda x: video_concat_modes[x][
0
], # The label is displayed to the user
) )
params.video_concat_mode = VideoConcatMode( params.video_concat_mode = VideoConcatMode(
video_concat_modes[selected_index][1] video_concat_modes[selected_index][1]
@@ -535,8 +550,12 @@ with middle_panel:
] ]
selected_index = st.selectbox( selected_index = st.selectbox(
tr("Video Ratio"), tr("Video Ratio"),
options=range(len(video_aspect_ratios)), # 使用索引作为内部选项值 options=range(
format_func=lambda x: video_aspect_ratios[x][0], # 显示给用户的是标签 len(video_aspect_ratios)
), # Use the index as the internal option value
format_func=lambda x: video_aspect_ratios[x][
0
], # The label is displayed to the user
) )
params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1]) params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
@@ -648,13 +667,17 @@ with middle_panel:
selected_index = st.selectbox( selected_index = st.selectbox(
tr("Background Music"), tr("Background Music"),
index=1, index=1,
options=range(len(bgm_options)), # 使用索引作为内部选项值 options=range(
format_func=lambda x: bgm_options[x][0], # 显示给用户的是标签 len(bgm_options)
), # Use the index as the internal option value
format_func=lambda x: bgm_options[x][
0
], # The label is displayed to the user
) )
# 获取选择的背景音乐类型 # Get the selected background music type
params.bgm_type = bgm_options[selected_index][1] params.bgm_type = bgm_options[selected_index][1]
# 根据选择显示或隐藏组件 # Show or hide components based on the selection
if params.bgm_type == "custom": if params.bgm_type == "custom":
custom_bgm_file = st.text_input(tr("Custom Background Music File")) custom_bgm_file = st.text_input(tr("Custom Background Music File"))
if custom_bgm_file and os.path.exists(custom_bgm_file): if custom_bgm_file and os.path.exists(custom_bgm_file):
@@ -733,15 +756,6 @@ if start_button:
scroll_to_bottom() scroll_to_bottom()
st.stop() st.stop()
if (
llm_provider != "g4f"
and llm_provider != "ollama"
and not config.app.get(f"{llm_provider}_api_key", "")
):
st.error(tr("Please Enter the LLM API Key"))
scroll_to_bottom()
st.stop()
if params.video_source not in ["pexels", "pixabay", "local"]: if params.video_source not in ["pexels", "pixabay", "local"]:
st.error(tr("Please Select a Valid Video Source")) st.error(tr("Please Select a Valid Video Source"))
scroll_to_bottom() scroll_to_bottom()