support baidu ERNIE llm
This commit is contained in:
@@ -72,6 +72,13 @@ def _generate_response(prompt: str) -> str:
|
||||
base_url = config.app.get("deepseek_base_url")
|
||||
if not base_url:
|
||||
base_url = "https://api.deepseek.com"
|
||||
elif llm_provider == "ernie":
|
||||
api_key = config.app.get("ernie_api_key")
|
||||
secret_key = config.app.get("ernie_secret_key")
|
||||
base_url = config.app.get("ernie_base_url")
|
||||
model_name = "***"
|
||||
if not secret_key:
|
||||
raise ValueError(f"{llm_provider}: secret_key is not set, please set it in the config.toml file.")
|
||||
else:
|
||||
raise ValueError("llm_provider is not set, please set it in the config.toml file.")
|
||||
|
||||
@@ -165,6 +172,34 @@ def _generate_response(prompt: str) -> str:
|
||||
logger.info(result)
|
||||
return result["result"]["response"]
|
||||
|
||||
if llm_provider == "ernie":
|
||||
import requests
|
||||
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
|
||||
access_token = requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params).json().get(
|
||||
"access_token")
|
||||
url = f"{base_url}?access_token={access_token}"
|
||||
|
||||
payload = json.dumps({
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
],
|
||||
"temperature": 0.5,
|
||||
"top_p": 0.8,
|
||||
"penalty_score": 1,
|
||||
"disable_search": False,
|
||||
"enable_citation": False,
|
||||
"response_format": "text"
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
response = requests.request("POST", url, headers=headers, data=payload).json()
|
||||
return response.get("result")
|
||||
|
||||
if llm_provider == "azure":
|
||||
client = AzureOpenAI(
|
||||
api_key=api_key,
|
||||
@@ -239,7 +274,7 @@ Generate a script for a video, depending on the subject of the video.
|
||||
selected_paragraphs = paragraphs[:paragraph_number]
|
||||
|
||||
# Join the selected paragraphs into a single string
|
||||
return "\n\n".join(selected_paragraphs)
|
||||
return "\n\n".join(paragraphs)
|
||||
|
||||
for i in range(_max_retries):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user