supported azure openai
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import List
|
||||
import g4f
|
||||
from loguru import logger
|
||||
from openai import OpenAI
|
||||
|
||||
from openai import AzureOpenAI
|
||||
from app.config import config
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ def _generate_response(prompt: str) -> str:
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
)
|
||||
else:
|
||||
api_version = "" # for azure
|
||||
if llm_provider == "moonshot":
|
||||
api_key = config.app.get("moonshot_api_key")
|
||||
model_name = config.app.get("moonshot_model_name")
|
||||
@@ -37,6 +38,11 @@ def _generate_response(prompt: str) -> str:
|
||||
api_key = config.app.get("oneapi_api_key")
|
||||
model_name = config.app.get("oneapi_model_name")
|
||||
base_url = config.app.get("oneapi_base_url", "")
|
||||
elif llm_provider == "azure":
|
||||
api_key = config.app.get("azure_api_key")
|
||||
model_name = config.app.get("azure_model_name")
|
||||
base_url = config.app.get("azure_base_url", "")
|
||||
api_version = config.app.get("azure_api_version", "2024-02-15-preview")
|
||||
else:
|
||||
raise ValueError("llm_provider is not set, please set it in the config.toml file.")
|
||||
|
||||
@@ -47,10 +53,17 @@ def _generate_response(prompt: str) -> str:
|
||||
if not base_url:
|
||||
raise ValueError(f"{llm_provider}: base_url is not set, please set it in the config.toml file.")
|
||||
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
if llm_provider == "azure":
|
||||
client = AzureOpenAI(
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
azure_endpoint=base_url,
|
||||
)
|
||||
else:
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model_name,
|
||||
|
||||
Reference in New Issue
Block a user