Format project code

This commit is contained in:
yyhhyyyyyy
2024-07-24 14:59:06 +08:00
parent bbd4e94941
commit 905841965a
18 changed files with 410 additions and 214 deletions

View File

@@ -21,6 +21,7 @@ def _generate_response(prompt: str) -> str:
if not model_name:
model_name = "gpt-3.5-turbo-16k-0613"
import g4f
content = g4f.ChatCompletion.create(
model=model_name,
messages=[{"role": "user", "content": prompt}],
@@ -78,44 +79,56 @@ def _generate_response(prompt: str) -> str:
base_url = config.app.get("ernie_base_url")
model_name = "***"
if not secret_key:
raise ValueError(f"{llm_provider}: secret_key is not set, please set it in the config.toml file.")
raise ValueError(
f"{llm_provider}: secret_key is not set, please set it in the config.toml file."
)
else:
raise ValueError("llm_provider is not set, please set it in the config.toml file.")
raise ValueError(
"llm_provider is not set, please set it in the config.toml file."
)
if not api_key:
raise ValueError(f"{llm_provider}: api_key is not set, please set it in the config.toml file.")
raise ValueError(
f"{llm_provider}: api_key is not set, please set it in the config.toml file."
)
if not model_name:
raise ValueError(f"{llm_provider}: model_name is not set, please set it in the config.toml file.")
raise ValueError(
f"{llm_provider}: model_name is not set, please set it in the config.toml file."
)
if not base_url:
raise ValueError(f"{llm_provider}: base_url is not set, please set it in the config.toml file.")
raise ValueError(
f"{llm_provider}: base_url is not set, please set it in the config.toml file."
)
if llm_provider == "qwen":
import dashscope
from dashscope.api_entities.dashscope_response import GenerationResponse
dashscope.api_key = api_key
response = dashscope.Generation.call(
model=model_name,
messages=[{"role": "user", "content": prompt}]
model=model_name, messages=[{"role": "user", "content": prompt}]
)
if response:
if isinstance(response, GenerationResponse):
status_code = response.status_code
if status_code != 200:
raise Exception(
f"[{llm_provider}] returned an error response: \"{response}\"")
f'[{llm_provider}] returned an error response: "{response}"'
)
content = response["output"]["text"]
return content.replace("\n", "")
else:
raise Exception(
f"[{llm_provider}] returned an invalid response: \"{response}\"")
f'[{llm_provider}] returned an invalid response: "{response}"'
)
else:
raise Exception(
f"[{llm_provider}] returned an empty response")
raise Exception(f"[{llm_provider}] returned an empty response")
if llm_provider == "gemini":
import google.generativeai as genai
genai.configure(api_key=api_key, transport='rest')
genai.configure(api_key=api_key, transport="rest")
generation_config = {
"temperature": 0.5,
@@ -127,25 +140,27 @@ def _generate_response(prompt: str) -> str:
safety_settings = [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_ONLY_HIGH"
"threshold": "BLOCK_ONLY_HIGH",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_ONLY_HIGH"
"threshold": "BLOCK_ONLY_HIGH",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_ONLY_HIGH"
"threshold": "BLOCK_ONLY_HIGH",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_ONLY_HIGH"
"threshold": "BLOCK_ONLY_HIGH",
},
]
model = genai.GenerativeModel(model_name=model_name,
generation_config=generation_config,
safety_settings=safety_settings)
model = genai.GenerativeModel(
model_name=model_name,
generation_config=generation_config,
safety_settings=safety_settings,
)
try:
response = model.generate_content(prompt)
@@ -158,15 +173,16 @@ def _generate_response(prompt: str) -> str:
if llm_provider == "cloudflare":
import requests
response = requests.post(
f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}",
headers={"Authorization": f"Bearer {api_key}"},
json={
"messages": [
{"role": "system", "content": "You are a friendly assistant"},
{"role": "user", "content": prompt}
{"role": "user", "content": prompt},
]
}
},
)
result = response.json()
logger.info(result)
@@ -174,30 +190,35 @@ def _generate_response(prompt: str) -> str:
if llm_provider == "ernie":
import requests
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
access_token = requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params).json().get(
"access_token")
params = {
"grant_type": "client_credentials",
"client_id": api_key,
"client_secret": secret_key,
}
access_token = (
requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params)
.json()
.get("access_token")
)
url = f"{base_url}?access_token={access_token}"
payload = json.dumps({
"messages": [
{
"role": "user",
"content": prompt
}
],
"temperature": 0.5,
"top_p": 0.8,
"penalty_score": 1,
"disable_search": False,
"enable_citation": False,
"response_format": "text"
})
headers = {
'Content-Type': 'application/json'
}
payload = json.dumps(
{
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.5,
"top_p": 0.8,
"penalty_score": 1,
"disable_search": False,
"enable_citation": False,
"response_format": "text",
}
)
headers = {"Content-Type": "application/json"}
response = requests.request("POST", url, headers=headers, data=payload).json()
response = requests.request(
"POST", url, headers=headers, data=payload
).json()
return response.get("result")
if llm_provider == "azure":
@@ -213,24 +234,27 @@ def _generate_response(prompt: str) -> str:
)
response = client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": prompt}]
model=model_name, messages=[{"role": "user", "content": prompt}]
)
if response:
if isinstance(response, ChatCompletion):
content = response.choices[0].message.content
else:
raise Exception(
f"[{llm_provider}] returned an invalid response: \"{response}\", please check your network "
f"connection and try again.")
f'[{llm_provider}] returned an invalid response: "{response}", please check your network '
f"connection and try again."
)
else:
raise Exception(
f"[{llm_provider}] returned an empty response, please check your network connection and try again.")
f"[{llm_provider}] returned an empty response, please check your network connection and try again."
)
return content.replace("\n", "")
def generate_script(video_subject: str, language: str = "", paragraph_number: int = 1) -> str:
def generate_script(
video_subject: str, language: str = "", paragraph_number: int = 1
) -> str:
prompt = f"""
# Role: Video Script Generator
@@ -335,14 +359,16 @@ Please note that you must use English for generating video search terms; Chinese
try:
response = _generate_response(prompt)
search_terms = json.loads(response)
if not isinstance(search_terms, list) or not all(isinstance(term, str) for term in search_terms):
if not isinstance(search_terms, list) or not all(
isinstance(term, str) for term in search_terms
):
logger.error("response is not a list of strings.")
continue
except Exception as e:
logger.warning(f"failed to generate video terms: {str(e)}")
if response:
match = re.search(r'\[.*]', response)
match = re.search(r"\[.*]", response)
if match:
try:
search_terms = json.loads(match.group())
@@ -361,9 +387,13 @@ Please note that you must use English for generating video search terms; Chinese
if __name__ == "__main__":
video_subject = "生命的意义是什么"
script = generate_script(video_subject=video_subject, language="zh-CN", paragraph_number=1)
script = generate_script(
video_subject=video_subject, language="zh-CN", paragraph_number=1
)
print("######################")
print(script)
search_terms = generate_terms(video_subject=video_subject, video_script=script, amount=5)
search_terms = generate_terms(
video_subject=video_subject, video_script=script, amount=5
)
print("######################")
print(search_terms)