Merge pull request #554 from yyhhyyyyyy/llm-logic
🐛 fix: fix the LLM logic
This commit is contained in:
@@ -3,6 +3,7 @@ import logging
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
import g4f
|
||||
from loguru import logger
|
||||
from openai import AzureOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
@@ -13,6 +14,7 @@ _max_retries = 5
|
||||
|
||||
|
||||
def _generate_response(prompt: str) -> str:
|
||||
try:
|
||||
content = ""
|
||||
llm_provider = config.app.get("llm_provider", "openai")
|
||||
logger.info(f"llm provider: {llm_provider}")
|
||||
@@ -20,8 +22,6 @@ def _generate_response(prompt: str) -> str:
|
||||
model_name = config.app.get("g4f_model_name", "")
|
||||
if not model_name:
|
||||
model_name = "gpt-3.5-turbo-16k-0613"
|
||||
import g4f
|
||||
|
||||
content = g4f.ChatCompletion.create(
|
||||
model=model_name,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
@@ -179,7 +179,10 @@ def _generate_response(prompt: str) -> str:
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
json={
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a friendly assistant"},
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a friendly assistant",
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
},
|
||||
@@ -197,7 +200,9 @@ def _generate_response(prompt: str) -> str:
|
||||
"client_secret": secret_key,
|
||||
}
|
||||
access_token = (
|
||||
requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params)
|
||||
requests.post(
|
||||
"https://aip.baidubce.com/oauth/2.0/token", params=params
|
||||
)
|
||||
.json()
|
||||
.get("access_token")
|
||||
)
|
||||
@@ -250,6 +255,8 @@ def _generate_response(prompt: str) -> str:
|
||||
)
|
||||
|
||||
return content.replace("\n", "")
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
|
||||
def generate_script(
|
||||
@@ -319,7 +326,9 @@ Generate a script for a video, depending on the subject of the video.
|
||||
|
||||
if i < _max_retries:
|
||||
logger.warning(f"failed to generate video script, trying again... {i + 1}")
|
||||
|
||||
if "Error: " in final_script:
|
||||
logger.error(f"failed to generate video script: {final_script}")
|
||||
else:
|
||||
logger.success(f"completed: \n{final_script}")
|
||||
return final_script.strip()
|
||||
|
||||
@@ -358,6 +367,9 @@ Please note that you must use English for generating video search terms; Chinese
|
||||
for i in range(_max_retries):
|
||||
try:
|
||||
response = _generate_response(prompt)
|
||||
if "Error: " in response:
|
||||
logger.error(f"failed to generate video script: {response}")
|
||||
return response
|
||||
search_terms = json.loads(response)
|
||||
if not isinstance(search_terms, list) or not all(
|
||||
isinstance(term, str) for term in search_terms
|
||||
|
||||
@@ -214,7 +214,7 @@ def start(task_id, params: VideoParams, stop_at: str = "video"):
|
||||
|
||||
# 1. Generate script
|
||||
video_script = generate_script(task_id, params)
|
||||
if not video_script:
|
||||
if not video_script or "Error: " in video_script:
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||
return
|
||||
|
||||
|
||||
@@ -449,8 +449,12 @@ with left_panel:
|
||||
selected_index = st.selectbox(
|
||||
tr("Script Language"),
|
||||
index=0,
|
||||
options=range(len(video_languages)), # 使用索引作为内部选项值
|
||||
format_func=lambda x: video_languages[x][0], # 显示给用户的是标签
|
||||
options=range(
|
||||
len(video_languages)
|
||||
), # Use the index as the internal option value
|
||||
format_func=lambda x: video_languages[x][
|
||||
0
|
||||
], # The label is displayed to the user
|
||||
)
|
||||
params.video_language = video_languages[selected_index][1]
|
||||
|
||||
@@ -462,9 +466,13 @@ with left_panel:
|
||||
video_subject=params.video_subject, language=params.video_language
|
||||
)
|
||||
terms = llm.generate_terms(params.video_subject, script)
|
||||
if "Error: " in script:
|
||||
st.error(tr(script))
|
||||
elif "Error: " in terms:
|
||||
st.error(tr(terms))
|
||||
else:
|
||||
st.session_state["video_script"] = script
|
||||
st.session_state["video_terms"] = ", ".join(terms)
|
||||
|
||||
params.video_script = st.text_area(
|
||||
tr("Video Script"), value=st.session_state["video_script"], height=280
|
||||
)
|
||||
@@ -475,6 +483,9 @@ with left_panel:
|
||||
|
||||
with st.spinner(tr("Generating Video Keywords")):
|
||||
terms = llm.generate_terms(params.video_subject, params.video_script)
|
||||
if "Error: " in terms:
|
||||
st.error(tr(terms))
|
||||
else:
|
||||
st.session_state["video_terms"] = ", ".join(terms)
|
||||
|
||||
params.video_terms = st.text_area(
|
||||
@@ -522,8 +533,12 @@ with middle_panel:
|
||||
selected_index = st.selectbox(
|
||||
tr("Video Concat Mode"),
|
||||
index=1,
|
||||
options=range(len(video_concat_modes)), # 使用索引作为内部选项值
|
||||
format_func=lambda x: video_concat_modes[x][0], # 显示给用户的是标签
|
||||
options=range(
|
||||
len(video_concat_modes)
|
||||
), # Use the index as the internal option value
|
||||
format_func=lambda x: video_concat_modes[x][
|
||||
0
|
||||
], # The label is displayed to the user
|
||||
)
|
||||
params.video_concat_mode = VideoConcatMode(
|
||||
video_concat_modes[selected_index][1]
|
||||
@@ -535,8 +550,12 @@ with middle_panel:
|
||||
]
|
||||
selected_index = st.selectbox(
|
||||
tr("Video Ratio"),
|
||||
options=range(len(video_aspect_ratios)), # 使用索引作为内部选项值
|
||||
format_func=lambda x: video_aspect_ratios[x][0], # 显示给用户的是标签
|
||||
options=range(
|
||||
len(video_aspect_ratios)
|
||||
), # Use the index as the internal option value
|
||||
format_func=lambda x: video_aspect_ratios[x][
|
||||
0
|
||||
], # The label is displayed to the user
|
||||
)
|
||||
params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
|
||||
|
||||
@@ -648,13 +667,17 @@ with middle_panel:
|
||||
selected_index = st.selectbox(
|
||||
tr("Background Music"),
|
||||
index=1,
|
||||
options=range(len(bgm_options)), # 使用索引作为内部选项值
|
||||
format_func=lambda x: bgm_options[x][0], # 显示给用户的是标签
|
||||
options=range(
|
||||
len(bgm_options)
|
||||
), # Use the index as the internal option value
|
||||
format_func=lambda x: bgm_options[x][
|
||||
0
|
||||
], # The label is displayed to the user
|
||||
)
|
||||
# 获取选择的背景音乐类型
|
||||
# Get the selected background music type
|
||||
params.bgm_type = bgm_options[selected_index][1]
|
||||
|
||||
# 根据选择显示或隐藏组件
|
||||
# Show or hide components based on the selection
|
||||
if params.bgm_type == "custom":
|
||||
custom_bgm_file = st.text_input(tr("Custom Background Music File"))
|
||||
if custom_bgm_file and os.path.exists(custom_bgm_file):
|
||||
@@ -733,15 +756,6 @@ if start_button:
|
||||
scroll_to_bottom()
|
||||
st.stop()
|
||||
|
||||
if (
|
||||
llm_provider != "g4f"
|
||||
and llm_provider != "ollama"
|
||||
and not config.app.get(f"{llm_provider}_api_key", "")
|
||||
):
|
||||
st.error(tr("Please Enter the LLM API Key"))
|
||||
scroll_to_bottom()
|
||||
st.stop()
|
||||
|
||||
if params.video_source not in ["pexels", "pixabay", "local"]:
|
||||
st.error(tr("Please Select a Valid Video Source"))
|
||||
scroll_to_bottom()
|
||||
|
||||
Reference in New Issue
Block a user