support baidu ERNIE llm

This commit is contained in:
harry
2024-07-03 21:12:21 +08:00
parent dca23d99e4
commit 66c81a04bf
5 changed files with 64 additions and 6 deletions

View File

@@ -72,6 +72,13 @@ def _generate_response(prompt: str) -> str:
base_url = config.app.get("deepseek_base_url")
if not base_url:
base_url = "https://api.deepseek.com"
elif llm_provider == "ernie":
api_key = config.app.get("ernie_api_key")
secret_key = config.app.get("ernie_secret_key")
base_url = config.app.get("ernie_base_url")
model_name = "***"
if not secret_key:
raise ValueError(f"{llm_provider}: secret_key is not set, please set it in the config.toml file.")
else:
raise ValueError("llm_provider is not set, please set it in the config.toml file.")
@@ -165,6 +172,34 @@ def _generate_response(prompt: str) -> str:
logger.info(result)
return result["result"]["response"]
if llm_provider == "ernie":
import requests
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
access_token = requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params).json().get(
"access_token")
url = f"{base_url}?access_token={access_token}"
payload = json.dumps({
"messages": [
{
"role": "user",
"content": prompt
}
],
"temperature": 0.5,
"top_p": 0.8,
"penalty_score": 1,
"disable_search": False,
"enable_citation": False,
"response_format": "text"
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload).json()
return response.get("result")
if llm_provider == "azure":
client = AzureOpenAI(
api_key=api_key,
@@ -239,7 +274,7 @@ Generate a script for a video, depending on the subject of the video.
selected_paragraphs = paragraphs[:paragraph_number]
# Join the selected paragraphs into a single string
return "\n\n".join(selected_paragraphs)
return "\n\n".join(paragraphs)
for i in range(_max_retries):
try: