supported azure openai

This commit is contained in:
harry
2024-03-25 15:20:00 +08:00
parent a52af0e532
commit 4ed3b9fbcc
2 changed files with 32 additions and 10 deletions

View File

@@ -5,7 +5,7 @@ from typing import List
import g4f
from loguru import logger
from openai import OpenAI
from openai import AzureOpenAI
from app.config import config
@@ -23,6 +23,7 @@ def _generate_response(prompt: str) -> str:
messages=[{"role": "user", "content": prompt}],
)
else:
api_version = "" # for azure
if llm_provider == "moonshot":
api_key = config.app.get("moonshot_api_key")
model_name = config.app.get("moonshot_model_name")
@@ -37,6 +38,11 @@ def _generate_response(prompt: str) -> str:
api_key = config.app.get("oneapi_api_key")
model_name = config.app.get("oneapi_model_name")
base_url = config.app.get("oneapi_base_url", "")
elif llm_provider == "azure":
api_key = config.app.get("azure_api_key")
model_name = config.app.get("azure_model_name")
base_url = config.app.get("azure_base_url", "")
api_version = config.app.get("azure_api_version", "2024-02-15-preview")
else:
raise ValueError("llm_provider is not set, please set it in the config.toml file.")
@@ -47,10 +53,17 @@ def _generate_response(prompt: str) -> str:
if not base_url:
raise ValueError(f"{llm_provider}: base_url is not set, please set it in the config.toml file.")
client = OpenAI(
api_key=api_key,
base_url=base_url,
)
if llm_provider == "azure":
client = AzureOpenAI(
api_key=api_key,
api_version=api_version,
azure_endpoint=base_url,
)
else:
client = OpenAI(
api_key=api_key,
base_url=base_url,
)
response = client.chat.completions.create(
model=model_name,