🐛 fix: fix the LLM logic

Author: yyhhyyyyyy
Date: 2024-12-12 14:29:14 +08:00
parent 85d446e2d0
commit 2d8cd23fe7
3 changed files with 270 additions and 249 deletions

View File

@@ -3,6 +3,7 @@ import logging
import re
from typing import List
+import g4f
from loguru import logger
from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletion
@@ -13,243 +14,244 @@ _max_retries = 5


def _generate_response(prompt: str) -> str:
+    try:
        content = ""
        llm_provider = config.app.get("llm_provider", "openai")
        logger.info(f"llm provider: {llm_provider}")
        if llm_provider == "g4f":
            model_name = config.app.get("g4f_model_name", "")
            if not model_name:
                model_name = "gpt-3.5-turbo-16k-0613"
-            import g4f
            content = g4f.ChatCompletion.create(
                model=model_name,
                messages=[{"role": "user", "content": prompt}],
            )
        else:
            api_version = ""  # for azure
            if llm_provider == "moonshot":
                api_key = config.app.get("moonshot_api_key")
                model_name = config.app.get("moonshot_model_name")
                base_url = "https://api.moonshot.cn/v1"
            elif llm_provider == "ollama":
                # api_key = config.app.get("openai_api_key")
                api_key = "ollama"  # any string works but you are required to have one
                model_name = config.app.get("ollama_model_name")
                base_url = config.app.get("ollama_base_url", "")
                if not base_url:
                    base_url = "http://localhost:11434/v1"
            elif llm_provider == "openai":
                api_key = config.app.get("openai_api_key")
                model_name = config.app.get("openai_model_name")
                base_url = config.app.get("openai_base_url", "")
                if not base_url:
                    base_url = "https://api.openai.com/v1"
            elif llm_provider == "oneapi":
                api_key = config.app.get("oneapi_api_key")
                model_name = config.app.get("oneapi_model_name")
                base_url = config.app.get("oneapi_base_url", "")
            elif llm_provider == "azure":
                api_key = config.app.get("azure_api_key")
                model_name = config.app.get("azure_model_name")
                base_url = config.app.get("azure_base_url", "")
                api_version = config.app.get("azure_api_version", "2024-02-15-preview")
            elif llm_provider == "gemini":
                api_key = config.app.get("gemini_api_key")
                model_name = config.app.get("gemini_model_name")
                base_url = "***"
            elif llm_provider == "qwen":
                api_key = config.app.get("qwen_api_key")
                model_name = config.app.get("qwen_model_name")
                base_url = "***"
            elif llm_provider == "cloudflare":
                api_key = config.app.get("cloudflare_api_key")
                model_name = config.app.get("cloudflare_model_name")
                account_id = config.app.get("cloudflare_account_id")
                base_url = "***"
            elif llm_provider == "deepseek":
                api_key = config.app.get("deepseek_api_key")
                model_name = config.app.get("deepseek_model_name")
                base_url = config.app.get("deepseek_base_url")
                if not base_url:
                    base_url = "https://api.deepseek.com"
            elif llm_provider == "ernie":
                api_key = config.app.get("ernie_api_key")
                secret_key = config.app.get("ernie_secret_key")
                base_url = config.app.get("ernie_base_url")
                model_name = "***"
                if not secret_key:
                    raise ValueError(
                        f"{llm_provider}: secret_key is not set, please set it in the config.toml file."
                    )
            else:
                raise ValueError(
                    "llm_provider is not set, please set it in the config.toml file."
                )

            if not api_key:
                raise ValueError(
                    f"{llm_provider}: api_key is not set, please set it in the config.toml file."
                )
            if not model_name:
                raise ValueError(
                    f"{llm_provider}: model_name is not set, please set it in the config.toml file."
                )
            if not base_url:
                raise ValueError(
                    f"{llm_provider}: base_url is not set, please set it in the config.toml file."
                )

            if llm_provider == "qwen":
                import dashscope
                from dashscope.api_entities.dashscope_response import GenerationResponse

                dashscope.api_key = api_key
                response = dashscope.Generation.call(
                    model=model_name, messages=[{"role": "user", "content": prompt}]
                )
                if response:
                    if isinstance(response, GenerationResponse):
                        status_code = response.status_code
                        if status_code != 200:
                            raise Exception(
                                f'[{llm_provider}] returned an error response: "{response}"'
                            )
                        content = response["output"]["text"]
                        return content.replace("\n", "")
                    else:
                        raise Exception(
                            f'[{llm_provider}] returned an invalid response: "{response}"'
                        )
                else:
                    raise Exception(f"[{llm_provider}] returned an empty response")

            if llm_provider == "gemini":
                import google.generativeai as genai

                genai.configure(api_key=api_key, transport="rest")
                generation_config = {
                    "temperature": 0.5,
                    "top_p": 1,
                    "top_k": 1,
                    "max_output_tokens": 2048,
                }
                safety_settings = [
                    {
                        "category": "HARM_CATEGORY_HARASSMENT",
                        "threshold": "BLOCK_ONLY_HIGH",
                    },
                    {
                        "category": "HARM_CATEGORY_HATE_SPEECH",
                        "threshold": "BLOCK_ONLY_HIGH",
                    },
                    {
                        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                        "threshold": "BLOCK_ONLY_HIGH",
                    },
                    {
                        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                        "threshold": "BLOCK_ONLY_HIGH",
                    },
                ]
                model = genai.GenerativeModel(
                    model_name=model_name,
                    generation_config=generation_config,
                    safety_settings=safety_settings,
                )
                try:
                    response = model.generate_content(prompt)
                    candidates = response.candidates
                    generated_text = candidates[0].content.parts[0].text
                except (AttributeError, IndexError) as e:
                    print("Gemini Error:", e)
                return generated_text

            if llm_provider == "cloudflare":
                import requests

                response = requests.post(
                    f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}",
                    headers={"Authorization": f"Bearer {api_key}"},
                    json={
                        "messages": [
                            {"role": "system", "content": "You are a friendly assistant"},
                            {"role": "user", "content": prompt},
                        ]
                    },
                )
                result = response.json()
                logger.info(result)
                return result["result"]["response"]

            if llm_provider == "ernie":
                import requests

                params = {
                    "grant_type": "client_credentials",
                    "client_id": api_key,
                    "client_secret": secret_key,
                }
                access_token = (
                    requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params)
                    .json()
                    .get("access_token")
                )
                url = f"{base_url}?access_token={access_token}"
                payload = json.dumps(
                    {
                        "messages": [{"role": "user", "content": prompt}],
                        "temperature": 0.5,
                        "top_p": 0.8,
                        "penalty_score": 1,
                        "disable_search": False,
                        "enable_citation": False,
                        "response_format": "text",
                    }
                )
                headers = {"Content-Type": "application/json"}
                response = requests.request(
                    "POST", url, headers=headers, data=payload
                ).json()
                return response.get("result")

            if llm_provider == "azure":
                client = AzureOpenAI(
                    api_key=api_key,
                    api_version=api_version,
                    azure_endpoint=base_url,
                )
            else:
                client = OpenAI(
                    api_key=api_key,
                    base_url=base_url,
                )

            response = client.chat.completions.create(
                model=model_name, messages=[{"role": "user", "content": prompt}]
            )
            if response:
                if isinstance(response, ChatCompletion):
                    content = response.choices[0].message.content
                else:
                    raise Exception(
                        f'[{llm_provider}] returned an invalid response: "{response}", please check your network '
                        f"connection and try again."
                    )
            else:
                raise Exception(
                    f"[{llm_provider}] returned an empty response, please check your network connection and try again."
                )

        return content.replace("\n", "")
+    except Exception as e:
+        return f"Error: {str(e)}"


def generate_script(
@@ -319,8 +321,10 @@ Generate a script for a video, depending on the subject of the video.
        if i < _max_retries:
            logger.warning(f"failed to generate video script, trying again... {i + 1}")

-    logger.success(f"completed: \n{final_script}")
+    if "Error: " in final_script:
+        logger.error(f"failed to generate video script: {final_script}")
+    else:
+        logger.success(f"completed: \n{final_script}")

    return final_script.strip()
@@ -358,6 +362,9 @@ Please note that you must use English for generating video search terms; Chinese
    for i in range(_max_retries):
        try:
            response = _generate_response(prompt)
+            if "Error: " in response:
+                logger.error(f"failed to generate video script: {response}")
+                return response
            search_terms = json.loads(response)
            if not isinstance(search_terms, list) or not all(
                isinstance(term, str) for term in search_terms
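For reference, a minimal sketch of the calling contract the two hunks above rely on: after this change, _generate_response() no longer raises on failure but returns a string prefixed with "Error: ", so callers must check for that prefix before using the result. The prompt text and log messages here are illustrative, not part of the commit:

    response = _generate_response(prompt="Suggest three video topics about coffee.")
    if response.startswith("Error: "):
        logger.error(f"LLM call failed: {response}")
    else:
        logger.success(f"LLM call succeeded: {response}")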

View File

@@ -214,7 +214,7 @@ def start(task_id, params: VideoParams, stop_at: str = "video"):
    # 1. Generate script
    video_script = generate_script(task_id, params)
-    if not video_script:
+    if not video_script or "Error: " in video_script:
        sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
        return

View File

@@ -449,8 +449,12 @@ with left_panel:
        selected_index = st.selectbox(
            tr("Script Language"),
            index=0,
-            options=range(len(video_languages)),  # Use the index as the internal option value
-            format_func=lambda x: video_languages[x][0],  # The label is displayed to the user
+            options=range(
+                len(video_languages)
+            ),  # Use the index as the internal option value
+            format_func=lambda x: video_languages[x][
+                0
+            ],  # The label is displayed to the user
        )
        params.video_language = video_languages[selected_index][1]
@@ -462,9 +466,13 @@ with left_panel:
                    video_subject=params.video_subject, language=params.video_language
                )
                terms = llm.generate_terms(params.video_subject, script)
-                st.session_state["video_script"] = script
-                st.session_state["video_terms"] = ", ".join(terms)
+                if "Error: " in script:
+                    st.error(tr(script))
+                elif "Error: " in terms:
+                    st.error(tr(terms))
+                else:
+                    st.session_state["video_script"] = script
+                    st.session_state["video_terms"] = ", ".join(terms)

        params.video_script = st.text_area(
            tr("Video Script"), value=st.session_state["video_script"], height=280
        )
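One subtlety the branches above rely on (the sample values below are made up): generate_terms() now returns either a list of search terms on success or a plain "Error: ..." string on failure, so the same in check acts as a list-membership test in one case and a substring test in the other:

    terms_ok = ["latte art", "espresso machine"]  # shape on success
    terms_err = "Error: Connection timed out"  # shape on failure
    "Error: " in terms_ok  # False: membership test against the list
    "Error: " in terms_err  # True: substring test on the string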
@@ -475,7 +483,10 @@ with left_panel:
            with st.spinner(tr("Generating Video Keywords")):
                terms = llm.generate_terms(params.video_subject, params.video_script)
-                st.session_state["video_terms"] = ", ".join(terms)
+                if "Error: " in terms:
+                    st.error(tr(terms))
+                else:
+                    st.session_state["video_terms"] = ", ".join(terms)

        params.video_terms = st.text_area(
            tr("Video Keywords"), value=st.session_state["video_terms"]
@@ -522,8 +533,12 @@ with middle_panel:
        selected_index = st.selectbox(
            tr("Video Concat Mode"),
            index=1,
-            options=range(len(video_concat_modes)),  # Use the index as the internal option value
-            format_func=lambda x: video_concat_modes[x][0],  # The label is displayed to the user
+            options=range(
+                len(video_concat_modes)
+            ),  # Use the index as the internal option value
+            format_func=lambda x: video_concat_modes[x][
+                0
+            ],  # The label is displayed to the user
        )
        params.video_concat_mode = VideoConcatMode(
            video_concat_modes[selected_index][1]
@@ -535,8 +550,12 @@ with middle_panel:
        ]
        selected_index = st.selectbox(
            tr("Video Ratio"),
-            options=range(len(video_aspect_ratios)),  # Use the index as the internal option value
-            format_func=lambda x: video_aspect_ratios[x][0],  # The label is displayed to the user
+            options=range(
+                len(video_aspect_ratios)
+            ),  # Use the index as the internal option value
+            format_func=lambda x: video_aspect_ratios[x][
+                0
+            ],  # The label is displayed to the user
        )
        params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
@@ -648,13 +667,17 @@ with middle_panel:
        selected_index = st.selectbox(
            tr("Background Music"),
            index=1,
-            options=range(len(bgm_options)),  # Use the index as the internal option value
-            format_func=lambda x: bgm_options[x][0],  # The label is displayed to the user
+            options=range(
+                len(bgm_options)
+            ),  # Use the index as the internal option value
+            format_func=lambda x: bgm_options[x][
+                0
+            ],  # The label is displayed to the user
        )
        # Get the selected background music type
        params.bgm_type = bgm_options[selected_index][1]

        # Show or hide components based on the selection
        if params.bgm_type == "custom":
            custom_bgm_file = st.text_input(tr("Custom Background Music File"))
            if custom_bgm_file and os.path.exists(custom_bgm_file):
@@ -733,15 +756,6 @@ if start_button:
        scroll_to_bottom()
        st.stop()

-    if (
-        llm_provider != "g4f"
-        and llm_provider != "ollama"
-        and not config.app.get(f"{llm_provider}_api_key", "")
-    ):
-        st.error(tr("Please Enter the LLM API Key"))
-        scroll_to_bottom()
-        st.stop()
-
    if params.video_source not in ["pexels", "pixabay", "local"]:
        st.error(tr("Please Select a Valid Video Source"))
        scroll_to_bottom()
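The pre-flight key check above can presumably be dropped because the try/except added to _generate_response() routes a missing key through the same "Error: " path the UI now displays. A sketch of that flow, with illustrative values (not part of the commit):

    # assuming config.toml sets llm_provider = "openai" but leaves openai_api_key empty
    script = llm.generate_script(video_subject="coffee", language="en-US")
    # the internal ValueError is caught and comes back as:
    #   "Error: openai: api_key is not set, please set it in the config.toml file."
    if "Error: " in script:
        st.error(tr(script))  # the failure reaches the user here instead of via the removed guard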