Merge pull request #127 from pratham-darooka/main
Added support for Google Gemini models
This commit is contained in:
@@ -5,9 +5,9 @@ from typing import List
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from openai import AzureOpenAI
|
from openai import AzureOpenAI
|
||||||
|
import google.generativeai as genai
|
||||||
from app.config import config
|
from app.config import config
|
||||||
|
|
||||||
|
|
||||||
def _generate_response(prompt: str) -> str:
|
def _generate_response(prompt: str) -> str:
|
||||||
content = ""
|
content = ""
|
||||||
llm_provider = config.app.get("llm_provider", "openai")
|
llm_provider = config.app.get("llm_provider", "openai")
|
||||||
@@ -42,6 +42,10 @@ def _generate_response(prompt: str) -> str:
|
|||||||
model_name = config.app.get("azure_model_name")
|
model_name = config.app.get("azure_model_name")
|
||||||
base_url = config.app.get("azure_base_url", "")
|
base_url = config.app.get("azure_base_url", "")
|
||||||
api_version = config.app.get("azure_api_version", "2024-02-15-preview")
|
api_version = config.app.get("azure_api_version", "2024-02-15-preview")
|
||||||
|
elif llm_provider == "gemini":
|
||||||
|
api_key = config.app.get("gemini_api_key")
|
||||||
|
model_name = config.app.get("gemini_model_name")
|
||||||
|
base_url = "***"
|
||||||
elif llm_provider == "qwen":
|
elif llm_provider == "qwen":
|
||||||
api_key = config.app.get("qwen_api_key")
|
api_key = config.app.get("qwen_api_key")
|
||||||
model_name = config.app.get("qwen_model_name")
|
model_name = config.app.get("qwen_model_name")
|
||||||
@@ -66,6 +70,44 @@ def _generate_response(prompt: str) -> str:
|
|||||||
content = response["output"]["text"]
|
content = response["output"]["text"]
|
||||||
return content.replace("\n", "")
|
return content.replace("\n", "")
|
||||||
|
|
||||||
|
if llm_provider == "gemini":
|
||||||
|
genai.configure(api_key=api_key)
|
||||||
|
|
||||||
|
generation_config = {
|
||||||
|
"temperature": 0.5,
|
||||||
|
"top_p": 1,
|
||||||
|
"top_k": 1,
|
||||||
|
"max_output_tokens": 2048,
|
||||||
|
}
|
||||||
|
|
||||||
|
safety_settings = [
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HARASSMENT",
|
||||||
|
"threshold": "BLOCK_ONLY_HIGH"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||||
|
"threshold": "BLOCK_ONLY_HIGH"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||||
|
"threshold": "BLOCK_ONLY_HIGH"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
|
"threshold": "BLOCK_ONLY_HIGH"
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
model = genai.GenerativeModel(model_name=model_name,
|
||||||
|
generation_config=generation_config,
|
||||||
|
safety_settings=safety_settings)
|
||||||
|
|
||||||
|
convo = model.start_chat(history=[])
|
||||||
|
|
||||||
|
convo.send_message(prompt)
|
||||||
|
return convo.last.text
|
||||||
|
|
||||||
if llm_provider == "azure":
|
if llm_provider == "azure":
|
||||||
client = AzureOpenAI(
|
client = AzureOpenAI(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
|||||||
@@ -51,6 +51,10 @@
|
|||||||
azure_model_name="gpt-35-turbo" # replace with your model deployment name
|
azure_model_name="gpt-35-turbo" # replace with your model deployment name
|
||||||
azure_api_version = "2024-02-15-preview"
|
azure_api_version = "2024-02-15-preview"
|
||||||
|
|
||||||
|
########## Gemini API Key
|
||||||
|
gemini_api_key=""
|
||||||
|
gemini_model_name = "gemini-1.0-pro"
|
||||||
|
|
||||||
########## Qwen API Key
|
########## Qwen API Key
|
||||||
# Visit https://dashscope.console.aliyun.com/apiKey to get your API key
|
# Visit https://dashscope.console.aliyun.com/apiKey to get your API key
|
||||||
# Visit below links to get more details
|
# Visit below links to get more details
|
||||||
|
|||||||
@@ -13,4 +13,5 @@ urllib3~=2.2.1
|
|||||||
pillow~=9.5.0
|
pillow~=9.5.0
|
||||||
pydantic~=2.6.3
|
pydantic~=2.6.3
|
||||||
g4f~=0.2.5.4
|
g4f~=0.2.5.4
|
||||||
dashscope~=1.15.0
|
dashscope~=1.15.0
|
||||||
|
google.generativeai~=0.4.1
|
||||||
Reference in New Issue
Block a user