🐛 fix: fix the LLM logic
This commit is contained in:
@@ -3,6 +3,7 @@ import logging
|
|||||||
import re
|
import re
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import g4f
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from openai import AzureOpenAI, OpenAI
|
from openai import AzureOpenAI, OpenAI
|
||||||
from openai.types.chat import ChatCompletion
|
from openai.types.chat import ChatCompletion
|
||||||
@@ -13,243 +14,244 @@ _max_retries = 5
|
|||||||
|
|
||||||
|
|
||||||
def _generate_response(prompt: str) -> str:
|
def _generate_response(prompt: str) -> str:
|
||||||
content = ""
|
try:
|
||||||
llm_provider = config.app.get("llm_provider", "openai")
|
content = ""
|
||||||
logger.info(f"llm provider: {llm_provider}")
|
llm_provider = config.app.get("llm_provider", "openai")
|
||||||
if llm_provider == "g4f":
|
logger.info(f"llm provider: {llm_provider}")
|
||||||
model_name = config.app.get("g4f_model_name", "")
|
if llm_provider == "g4f":
|
||||||
if not model_name:
|
model_name = config.app.get("g4f_model_name", "")
|
||||||
model_name = "gpt-3.5-turbo-16k-0613"
|
if not model_name:
|
||||||
import g4f
|
model_name = "gpt-3.5-turbo-16k-0613"
|
||||||
|
content = g4f.ChatCompletion.create(
|
||||||
content = g4f.ChatCompletion.create(
|
model=model_name,
|
||||||
model=model_name,
|
messages=[{"role": "user", "content": prompt}],
|
||||||
messages=[{"role": "user", "content": prompt}],
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
api_version = "" # for azure
|
|
||||||
if llm_provider == "moonshot":
|
|
||||||
api_key = config.app.get("moonshot_api_key")
|
|
||||||
model_name = config.app.get("moonshot_model_name")
|
|
||||||
base_url = "https://api.moonshot.cn/v1"
|
|
||||||
elif llm_provider == "ollama":
|
|
||||||
# api_key = config.app.get("openai_api_key")
|
|
||||||
api_key = "ollama" # any string works but you are required to have one
|
|
||||||
model_name = config.app.get("ollama_model_name")
|
|
||||||
base_url = config.app.get("ollama_base_url", "")
|
|
||||||
if not base_url:
|
|
||||||
base_url = "http://localhost:11434/v1"
|
|
||||||
elif llm_provider == "openai":
|
|
||||||
api_key = config.app.get("openai_api_key")
|
|
||||||
model_name = config.app.get("openai_model_name")
|
|
||||||
base_url = config.app.get("openai_base_url", "")
|
|
||||||
if not base_url:
|
|
||||||
base_url = "https://api.openai.com/v1"
|
|
||||||
elif llm_provider == "oneapi":
|
|
||||||
api_key = config.app.get("oneapi_api_key")
|
|
||||||
model_name = config.app.get("oneapi_model_name")
|
|
||||||
base_url = config.app.get("oneapi_base_url", "")
|
|
||||||
elif llm_provider == "azure":
|
|
||||||
api_key = config.app.get("azure_api_key")
|
|
||||||
model_name = config.app.get("azure_model_name")
|
|
||||||
base_url = config.app.get("azure_base_url", "")
|
|
||||||
api_version = config.app.get("azure_api_version", "2024-02-15-preview")
|
|
||||||
elif llm_provider == "gemini":
|
|
||||||
api_key = config.app.get("gemini_api_key")
|
|
||||||
model_name = config.app.get("gemini_model_name")
|
|
||||||
base_url = "***"
|
|
||||||
elif llm_provider == "qwen":
|
|
||||||
api_key = config.app.get("qwen_api_key")
|
|
||||||
model_name = config.app.get("qwen_model_name")
|
|
||||||
base_url = "***"
|
|
||||||
elif llm_provider == "cloudflare":
|
|
||||||
api_key = config.app.get("cloudflare_api_key")
|
|
||||||
model_name = config.app.get("cloudflare_model_name")
|
|
||||||
account_id = config.app.get("cloudflare_account_id")
|
|
||||||
base_url = "***"
|
|
||||||
elif llm_provider == "deepseek":
|
|
||||||
api_key = config.app.get("deepseek_api_key")
|
|
||||||
model_name = config.app.get("deepseek_model_name")
|
|
||||||
base_url = config.app.get("deepseek_base_url")
|
|
||||||
if not base_url:
|
|
||||||
base_url = "https://api.deepseek.com"
|
|
||||||
elif llm_provider == "ernie":
|
|
||||||
api_key = config.app.get("ernie_api_key")
|
|
||||||
secret_key = config.app.get("ernie_secret_key")
|
|
||||||
base_url = config.app.get("ernie_base_url")
|
|
||||||
model_name = "***"
|
|
||||||
if not secret_key:
|
|
||||||
raise ValueError(
|
|
||||||
f"{llm_provider}: secret_key is not set, please set it in the config.toml file."
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
api_version = "" # for azure
|
||||||
"llm_provider is not set, please set it in the config.toml file."
|
if llm_provider == "moonshot":
|
||||||
)
|
api_key = config.app.get("moonshot_api_key")
|
||||||
|
model_name = config.app.get("moonshot_model_name")
|
||||||
|
base_url = "https://api.moonshot.cn/v1"
|
||||||
|
elif llm_provider == "ollama":
|
||||||
|
# api_key = config.app.get("openai_api_key")
|
||||||
|
api_key = "ollama" # any string works but you are required to have one
|
||||||
|
model_name = config.app.get("ollama_model_name")
|
||||||
|
base_url = config.app.get("ollama_base_url", "")
|
||||||
|
if not base_url:
|
||||||
|
base_url = "http://localhost:11434/v1"
|
||||||
|
elif llm_provider == "openai":
|
||||||
|
api_key = config.app.get("openai_api_key")
|
||||||
|
model_name = config.app.get("openai_model_name")
|
||||||
|
base_url = config.app.get("openai_base_url", "")
|
||||||
|
if not base_url:
|
||||||
|
base_url = "https://api.openai.com/v1"
|
||||||
|
elif llm_provider == "oneapi":
|
||||||
|
api_key = config.app.get("oneapi_api_key")
|
||||||
|
model_name = config.app.get("oneapi_model_name")
|
||||||
|
base_url = config.app.get("oneapi_base_url", "")
|
||||||
|
elif llm_provider == "azure":
|
||||||
|
api_key = config.app.get("azure_api_key")
|
||||||
|
model_name = config.app.get("azure_model_name")
|
||||||
|
base_url = config.app.get("azure_base_url", "")
|
||||||
|
api_version = config.app.get("azure_api_version", "2024-02-15-preview")
|
||||||
|
elif llm_provider == "gemini":
|
||||||
|
api_key = config.app.get("gemini_api_key")
|
||||||
|
model_name = config.app.get("gemini_model_name")
|
||||||
|
base_url = "***"
|
||||||
|
elif llm_provider == "qwen":
|
||||||
|
api_key = config.app.get("qwen_api_key")
|
||||||
|
model_name = config.app.get("qwen_model_name")
|
||||||
|
base_url = "***"
|
||||||
|
elif llm_provider == "cloudflare":
|
||||||
|
api_key = config.app.get("cloudflare_api_key")
|
||||||
|
model_name = config.app.get("cloudflare_model_name")
|
||||||
|
account_id = config.app.get("cloudflare_account_id")
|
||||||
|
base_url = "***"
|
||||||
|
elif llm_provider == "deepseek":
|
||||||
|
api_key = config.app.get("deepseek_api_key")
|
||||||
|
model_name = config.app.get("deepseek_model_name")
|
||||||
|
base_url = config.app.get("deepseek_base_url")
|
||||||
|
if not base_url:
|
||||||
|
base_url = "https://api.deepseek.com"
|
||||||
|
elif llm_provider == "ernie":
|
||||||
|
api_key = config.app.get("ernie_api_key")
|
||||||
|
secret_key = config.app.get("ernie_secret_key")
|
||||||
|
base_url = config.app.get("ernie_base_url")
|
||||||
|
model_name = "***"
|
||||||
|
if not secret_key:
|
||||||
|
raise ValueError(
|
||||||
|
f"{llm_provider}: secret_key is not set, please set it in the config.toml file."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"llm_provider is not set, please set it in the config.toml file."
|
||||||
|
)
|
||||||
|
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{llm_provider}: api_key is not set, please set it in the config.toml file."
|
f"{llm_provider}: api_key is not set, please set it in the config.toml file."
|
||||||
)
|
)
|
||||||
if not model_name:
|
if not model_name:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{llm_provider}: model_name is not set, please set it in the config.toml file."
|
f"{llm_provider}: model_name is not set, please set it in the config.toml file."
|
||||||
)
|
)
|
||||||
if not base_url:
|
if not base_url:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{llm_provider}: base_url is not set, please set it in the config.toml file."
|
f"{llm_provider}: base_url is not set, please set it in the config.toml file."
|
||||||
)
|
)
|
||||||
|
|
||||||
if llm_provider == "qwen":
|
if llm_provider == "qwen":
|
||||||
import dashscope
|
import dashscope
|
||||||
from dashscope.api_entities.dashscope_response import GenerationResponse
|
from dashscope.api_entities.dashscope_response import GenerationResponse
|
||||||
|
|
||||||
dashscope.api_key = api_key
|
dashscope.api_key = api_key
|
||||||
response = dashscope.Generation.call(
|
response = dashscope.Generation.call(
|
||||||
|
model=model_name, messages=[{"role": "user", "content": prompt}]
|
||||||
|
)
|
||||||
|
if response:
|
||||||
|
if isinstance(response, GenerationResponse):
|
||||||
|
status_code = response.status_code
|
||||||
|
if status_code != 200:
|
||||||
|
raise Exception(
|
||||||
|
f'[{llm_provider}] returned an error response: "{response}"'
|
||||||
|
)
|
||||||
|
|
||||||
|
content = response["output"]["text"]
|
||||||
|
return content.replace("\n", "")
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
f'[{llm_provider}] returned an invalid response: "{response}"'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception(f"[{llm_provider}] returned an empty response")
|
||||||
|
|
||||||
|
if llm_provider == "gemini":
|
||||||
|
import google.generativeai as genai
|
||||||
|
|
||||||
|
genai.configure(api_key=api_key, transport="rest")
|
||||||
|
|
||||||
|
generation_config = {
|
||||||
|
"temperature": 0.5,
|
||||||
|
"top_p": 1,
|
||||||
|
"top_k": 1,
|
||||||
|
"max_output_tokens": 2048,
|
||||||
|
}
|
||||||
|
|
||||||
|
safety_settings = [
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HARASSMENT",
|
||||||
|
"threshold": "BLOCK_ONLY_HIGH",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||||
|
"threshold": "BLOCK_ONLY_HIGH",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||||
|
"threshold": "BLOCK_ONLY_HIGH",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
|
"threshold": "BLOCK_ONLY_HIGH",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
model = genai.GenerativeModel(
|
||||||
|
model_name=model_name,
|
||||||
|
generation_config=generation_config,
|
||||||
|
safety_settings=safety_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = model.generate_content(prompt)
|
||||||
|
candidates = response.candidates
|
||||||
|
generated_text = candidates[0].content.parts[0].text
|
||||||
|
except (AttributeError, IndexError) as e:
|
||||||
|
print("Gemini Error:", e)
|
||||||
|
|
||||||
|
return generated_text
|
||||||
|
|
||||||
|
if llm_provider == "cloudflare":
|
||||||
|
import requests
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}",
|
||||||
|
headers={"Authorization": f"Bearer {api_key}"},
|
||||||
|
json={
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are a friendly assistant"},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
result = response.json()
|
||||||
|
logger.info(result)
|
||||||
|
return result["result"]["response"]
|
||||||
|
|
||||||
|
if llm_provider == "ernie":
|
||||||
|
import requests
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"grant_type": "client_credentials",
|
||||||
|
"client_id": api_key,
|
||||||
|
"client_secret": secret_key,
|
||||||
|
}
|
||||||
|
access_token = (
|
||||||
|
requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params)
|
||||||
|
.json()
|
||||||
|
.get("access_token")
|
||||||
|
)
|
||||||
|
url = f"{base_url}?access_token={access_token}"
|
||||||
|
|
||||||
|
payload = json.dumps(
|
||||||
|
{
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"temperature": 0.5,
|
||||||
|
"top_p": 0.8,
|
||||||
|
"penalty_score": 1,
|
||||||
|
"disable_search": False,
|
||||||
|
"enable_citation": False,
|
||||||
|
"response_format": "text",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
|
response = requests.request(
|
||||||
|
"POST", url, headers=headers, data=payload
|
||||||
|
).json()
|
||||||
|
return response.get("result")
|
||||||
|
|
||||||
|
if llm_provider == "azure":
|
||||||
|
client = AzureOpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_endpoint=base_url,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
client = OpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
model=model_name, messages=[{"role": "user", "content": prompt}]
|
model=model_name, messages=[{"role": "user", "content": prompt}]
|
||||||
)
|
)
|
||||||
if response:
|
if response:
|
||||||
if isinstance(response, GenerationResponse):
|
if isinstance(response, ChatCompletion):
|
||||||
status_code = response.status_code
|
content = response.choices[0].message.content
|
||||||
if status_code != 200:
|
|
||||||
raise Exception(
|
|
||||||
f'[{llm_provider}] returned an error response: "{response}"'
|
|
||||||
)
|
|
||||||
|
|
||||||
content = response["output"]["text"]
|
|
||||||
return content.replace("\n", "")
|
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f'[{llm_provider}] returned an invalid response: "{response}"'
|
f'[{llm_provider}] returned an invalid response: "{response}", please check your network '
|
||||||
|
f"connection and try again."
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise Exception(f"[{llm_provider}] returned an empty response")
|
|
||||||
|
|
||||||
if llm_provider == "gemini":
|
|
||||||
import google.generativeai as genai
|
|
||||||
|
|
||||||
genai.configure(api_key=api_key, transport="rest")
|
|
||||||
|
|
||||||
generation_config = {
|
|
||||||
"temperature": 0.5,
|
|
||||||
"top_p": 1,
|
|
||||||
"top_k": 1,
|
|
||||||
"max_output_tokens": 2048,
|
|
||||||
}
|
|
||||||
|
|
||||||
safety_settings = [
|
|
||||||
{
|
|
||||||
"category": "HARM_CATEGORY_HARASSMENT",
|
|
||||||
"threshold": "BLOCK_ONLY_HIGH",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
|
||||||
"threshold": "BLOCK_ONLY_HIGH",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
|
||||||
"threshold": "BLOCK_ONLY_HIGH",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
|
||||||
"threshold": "BLOCK_ONLY_HIGH",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
model = genai.GenerativeModel(
|
|
||||||
model_name=model_name,
|
|
||||||
generation_config=generation_config,
|
|
||||||
safety_settings=safety_settings,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = model.generate_content(prompt)
|
|
||||||
candidates = response.candidates
|
|
||||||
generated_text = candidates[0].content.parts[0].text
|
|
||||||
except (AttributeError, IndexError) as e:
|
|
||||||
print("Gemini Error:", e)
|
|
||||||
|
|
||||||
return generated_text
|
|
||||||
|
|
||||||
if llm_provider == "cloudflare":
|
|
||||||
import requests
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}",
|
|
||||||
headers={"Authorization": f"Bearer {api_key}"},
|
|
||||||
json={
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are a friendly assistant"},
|
|
||||||
{"role": "user", "content": prompt},
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
result = response.json()
|
|
||||||
logger.info(result)
|
|
||||||
return result["result"]["response"]
|
|
||||||
|
|
||||||
if llm_provider == "ernie":
|
|
||||||
import requests
|
|
||||||
|
|
||||||
params = {
|
|
||||||
"grant_type": "client_credentials",
|
|
||||||
"client_id": api_key,
|
|
||||||
"client_secret": secret_key,
|
|
||||||
}
|
|
||||||
access_token = (
|
|
||||||
requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params)
|
|
||||||
.json()
|
|
||||||
.get("access_token")
|
|
||||||
)
|
|
||||||
url = f"{base_url}?access_token={access_token}"
|
|
||||||
|
|
||||||
payload = json.dumps(
|
|
||||||
{
|
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
|
||||||
"temperature": 0.5,
|
|
||||||
"top_p": 0.8,
|
|
||||||
"penalty_score": 1,
|
|
||||||
"disable_search": False,
|
|
||||||
"enable_citation": False,
|
|
||||||
"response_format": "text",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
headers = {"Content-Type": "application/json"}
|
|
||||||
|
|
||||||
response = requests.request(
|
|
||||||
"POST", url, headers=headers, data=payload
|
|
||||||
).json()
|
|
||||||
return response.get("result")
|
|
||||||
|
|
||||||
if llm_provider == "azure":
|
|
||||||
client = AzureOpenAI(
|
|
||||||
api_key=api_key,
|
|
||||||
api_version=api_version,
|
|
||||||
azure_endpoint=base_url,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
client = OpenAI(
|
|
||||||
api_key=api_key,
|
|
||||||
base_url=base_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.chat.completions.create(
|
|
||||||
model=model_name, messages=[{"role": "user", "content": prompt}]
|
|
||||||
)
|
|
||||||
if response:
|
|
||||||
if isinstance(response, ChatCompletion):
|
|
||||||
content = response.choices[0].message.content
|
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f'[{llm_provider}] returned an invalid response: "{response}", please check your network '
|
f"[{llm_provider}] returned an empty response, please check your network connection and try again."
|
||||||
f"connection and try again."
|
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
f"[{llm_provider}] returned an empty response, please check your network connection and try again."
|
|
||||||
)
|
|
||||||
|
|
||||||
return content.replace("\n", "")
|
return content.replace("\n", "")
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
def generate_script(
|
def generate_script(
|
||||||
@@ -319,8 +321,10 @@ Generate a script for a video, depending on the subject of the video.
|
|||||||
|
|
||||||
if i < _max_retries:
|
if i < _max_retries:
|
||||||
logger.warning(f"failed to generate video script, trying again... {i + 1}")
|
logger.warning(f"failed to generate video script, trying again... {i + 1}")
|
||||||
|
if "Error: " in final_script:
|
||||||
logger.success(f"completed: \n{final_script}")
|
logger.error(f"failed to generate video script: {final_script}")
|
||||||
|
else:
|
||||||
|
logger.success(f"completed: \n{final_script}")
|
||||||
return final_script.strip()
|
return final_script.strip()
|
||||||
|
|
||||||
|
|
||||||
@@ -358,6 +362,9 @@ Please note that you must use English for generating video search terms; Chinese
|
|||||||
for i in range(_max_retries):
|
for i in range(_max_retries):
|
||||||
try:
|
try:
|
||||||
response = _generate_response(prompt)
|
response = _generate_response(prompt)
|
||||||
|
if "Error: " in response:
|
||||||
|
logger.error(f"failed to generate video script: {response}")
|
||||||
|
return response
|
||||||
search_terms = json.loads(response)
|
search_terms = json.loads(response)
|
||||||
if not isinstance(search_terms, list) or not all(
|
if not isinstance(search_terms, list) or not all(
|
||||||
isinstance(term, str) for term in search_terms
|
isinstance(term, str) for term in search_terms
|
||||||
|
|||||||
@@ -214,7 +214,7 @@ def start(task_id, params: VideoParams, stop_at: str = "video"):
|
|||||||
|
|
||||||
# 1. Generate script
|
# 1. Generate script
|
||||||
video_script = generate_script(task_id, params)
|
video_script = generate_script(task_id, params)
|
||||||
if not video_script:
|
if not video_script or "Error: " in video_script:
|
||||||
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -449,8 +449,12 @@ with left_panel:
|
|||||||
selected_index = st.selectbox(
|
selected_index = st.selectbox(
|
||||||
tr("Script Language"),
|
tr("Script Language"),
|
||||||
index=0,
|
index=0,
|
||||||
options=range(len(video_languages)), # 使用索引作为内部选项值
|
options=range(
|
||||||
format_func=lambda x: video_languages[x][0], # 显示给用户的是标签
|
len(video_languages)
|
||||||
|
), # Use the index as the internal option value
|
||||||
|
format_func=lambda x: video_languages[x][
|
||||||
|
0
|
||||||
|
], # The label is displayed to the user
|
||||||
)
|
)
|
||||||
params.video_language = video_languages[selected_index][1]
|
params.video_language = video_languages[selected_index][1]
|
||||||
|
|
||||||
@@ -462,9 +466,13 @@ with left_panel:
|
|||||||
video_subject=params.video_subject, language=params.video_language
|
video_subject=params.video_subject, language=params.video_language
|
||||||
)
|
)
|
||||||
terms = llm.generate_terms(params.video_subject, script)
|
terms = llm.generate_terms(params.video_subject, script)
|
||||||
st.session_state["video_script"] = script
|
if "Error: " in script:
|
||||||
st.session_state["video_terms"] = ", ".join(terms)
|
st.error(tr(script))
|
||||||
|
elif "Error: " in terms:
|
||||||
|
st.error(tr(terms))
|
||||||
|
else:
|
||||||
|
st.session_state["video_script"] = script
|
||||||
|
st.session_state["video_terms"] = ", ".join(terms)
|
||||||
params.video_script = st.text_area(
|
params.video_script = st.text_area(
|
||||||
tr("Video Script"), value=st.session_state["video_script"], height=280
|
tr("Video Script"), value=st.session_state["video_script"], height=280
|
||||||
)
|
)
|
||||||
@@ -475,7 +483,10 @@ with left_panel:
|
|||||||
|
|
||||||
with st.spinner(tr("Generating Video Keywords")):
|
with st.spinner(tr("Generating Video Keywords")):
|
||||||
terms = llm.generate_terms(params.video_subject, params.video_script)
|
terms = llm.generate_terms(params.video_subject, params.video_script)
|
||||||
st.session_state["video_terms"] = ", ".join(terms)
|
if "Error: " in terms:
|
||||||
|
st.error(tr(terms))
|
||||||
|
else:
|
||||||
|
st.session_state["video_terms"] = ", ".join(terms)
|
||||||
|
|
||||||
params.video_terms = st.text_area(
|
params.video_terms = st.text_area(
|
||||||
tr("Video Keywords"), value=st.session_state["video_terms"]
|
tr("Video Keywords"), value=st.session_state["video_terms"]
|
||||||
@@ -522,8 +533,12 @@ with middle_panel:
|
|||||||
selected_index = st.selectbox(
|
selected_index = st.selectbox(
|
||||||
tr("Video Concat Mode"),
|
tr("Video Concat Mode"),
|
||||||
index=1,
|
index=1,
|
||||||
options=range(len(video_concat_modes)), # 使用索引作为内部选项值
|
options=range(
|
||||||
format_func=lambda x: video_concat_modes[x][0], # 显示给用户的是标签
|
len(video_concat_modes)
|
||||||
|
), # Use the index as the internal option value
|
||||||
|
format_func=lambda x: video_concat_modes[x][
|
||||||
|
0
|
||||||
|
], # The label is displayed to the user
|
||||||
)
|
)
|
||||||
params.video_concat_mode = VideoConcatMode(
|
params.video_concat_mode = VideoConcatMode(
|
||||||
video_concat_modes[selected_index][1]
|
video_concat_modes[selected_index][1]
|
||||||
@@ -535,8 +550,12 @@ with middle_panel:
|
|||||||
]
|
]
|
||||||
selected_index = st.selectbox(
|
selected_index = st.selectbox(
|
||||||
tr("Video Ratio"),
|
tr("Video Ratio"),
|
||||||
options=range(len(video_aspect_ratios)), # 使用索引作为内部选项值
|
options=range(
|
||||||
format_func=lambda x: video_aspect_ratios[x][0], # 显示给用户的是标签
|
len(video_aspect_ratios)
|
||||||
|
), # Use the index as the internal option value
|
||||||
|
format_func=lambda x: video_aspect_ratios[x][
|
||||||
|
0
|
||||||
|
], # The label is displayed to the user
|
||||||
)
|
)
|
||||||
params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
|
params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
|
||||||
|
|
||||||
@@ -648,13 +667,17 @@ with middle_panel:
|
|||||||
selected_index = st.selectbox(
|
selected_index = st.selectbox(
|
||||||
tr("Background Music"),
|
tr("Background Music"),
|
||||||
index=1,
|
index=1,
|
||||||
options=range(len(bgm_options)), # 使用索引作为内部选项值
|
options=range(
|
||||||
format_func=lambda x: bgm_options[x][0], # 显示给用户的是标签
|
len(bgm_options)
|
||||||
|
), # Use the index as the internal option value
|
||||||
|
format_func=lambda x: bgm_options[x][
|
||||||
|
0
|
||||||
|
], # The label is displayed to the user
|
||||||
)
|
)
|
||||||
# 获取选择的背景音乐类型
|
# Get the selected background music type
|
||||||
params.bgm_type = bgm_options[selected_index][1]
|
params.bgm_type = bgm_options[selected_index][1]
|
||||||
|
|
||||||
# 根据选择显示或隐藏组件
|
# Show or hide components based on the selection
|
||||||
if params.bgm_type == "custom":
|
if params.bgm_type == "custom":
|
||||||
custom_bgm_file = st.text_input(tr("Custom Background Music File"))
|
custom_bgm_file = st.text_input(tr("Custom Background Music File"))
|
||||||
if custom_bgm_file and os.path.exists(custom_bgm_file):
|
if custom_bgm_file and os.path.exists(custom_bgm_file):
|
||||||
@@ -733,15 +756,6 @@ if start_button:
|
|||||||
scroll_to_bottom()
|
scroll_to_bottom()
|
||||||
st.stop()
|
st.stop()
|
||||||
|
|
||||||
if (
|
|
||||||
llm_provider != "g4f"
|
|
||||||
and llm_provider != "ollama"
|
|
||||||
and not config.app.get(f"{llm_provider}_api_key", "")
|
|
||||||
):
|
|
||||||
st.error(tr("Please Enter the LLM API Key"))
|
|
||||||
scroll_to_bottom()
|
|
||||||
st.stop()
|
|
||||||
|
|
||||||
if params.video_source not in ["pexels", "pixabay", "local"]:
|
if params.video_source not in ["pexels", "pixabay", "local"]:
|
||||||
st.error(tr("Please Select a Valid Video Source"))
|
st.error(tr("Please Select a Valid Video Source"))
|
||||||
scroll_to_bottom()
|
scroll_to_bottom()
|
||||||
|
|||||||
Reference in New Issue
Block a user