added support for google gemini

This commit is contained in:
PD
2024-03-31 10:44:52 +05:30
committed by GitHub
parent 478207fa7b
commit cc1f157714
3 changed files with 48 additions and 2 deletions

View File

@@ -5,9 +5,9 @@ from typing import List
from loguru import logger from loguru import logger
from openai import OpenAI from openai import OpenAI
from openai import AzureOpenAI from openai import AzureOpenAI
import google.generativeai as genai
from app.config import config from app.config import config
def _generate_response(prompt: str) -> str: def _generate_response(prompt: str) -> str:
content = "" content = ""
llm_provider = config.app.get("llm_provider", "openai") llm_provider = config.app.get("llm_provider", "openai")
@@ -42,6 +42,10 @@ def _generate_response(prompt: str) -> str:
model_name = config.app.get("azure_model_name") model_name = config.app.get("azure_model_name")
base_url = config.app.get("azure_base_url", "") base_url = config.app.get("azure_base_url", "")
api_version = config.app.get("azure_api_version", "2024-02-15-preview") api_version = config.app.get("azure_api_version", "2024-02-15-preview")
elif llm_provider == "gemini":
api_key = config.app.get("gemini_api_key")
model_name = config.app.get("gemini_model_name")
base_url = ""
elif llm_provider == "qwen": elif llm_provider == "qwen":
api_key = config.app.get("qwen_api_key") api_key = config.app.get("qwen_api_key")
model_name = config.app.get("qwen_model_name") model_name = config.app.get("qwen_model_name")
@@ -66,6 +70,44 @@ def _generate_response(prompt: str) -> str:
content = response["output"]["text"] content = response["output"]["text"]
return content.replace("\n", "") return content.replace("\n", "")
if llm_provider == "gemini":
genai.configure(api_key=api_key)
generation_config = {
"temperature": 0.5,
"top_p": 1,
"top_k": 1,
"max_output_tokens": 2048,
}
safety_settings = [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_ONLY_HIGH"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_ONLY_HIGH"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_ONLY_HIGH"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_ONLY_HIGH"
},
]
model = genai.GenerativeModel(model_name=model_name,
generation_config=generation_config,
safety_settings=safety_settings)
convo = model.start_chat(history=[])
convo.send_message(prompt)
return convo.last.text
if llm_provider == "azure": if llm_provider == "azure":
client = AzureOpenAI( client = AzureOpenAI(
api_key=api_key, api_key=api_key,

View File

@@ -51,6 +51,10 @@
azure_model_name="gpt-35-turbo" # replace with your model deployment name azure_model_name="gpt-35-turbo" # replace with your model deployment name
azure_api_version = "2024-02-15-preview" azure_api_version = "2024-02-15-preview"
########## Gemini API Key
gemini_api_key=""
gemini_model_name = "gemini-1.0-pro"
########## Qwen API Key ########## Qwen API Key
# Visit https://dashscope.console.aliyun.com/apiKey to get your API key # Visit https://dashscope.console.aliyun.com/apiKey to get your API key
# Visit below links to get more details # Visit below links to get more details

View File

@@ -13,4 +13,4 @@ urllib3~=2.2.1
pillow~=9.5.0 pillow~=9.5.0
pydantic~=2.6.3 pydantic~=2.6.3
g4f~=0.2.5.4 g4f~=0.2.5.4
dashscope~=1.15.0 dashscope~=1.15.0