optimize code
This commit is contained in:
@@ -5,9 +5,9 @@ from typing import List
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from openai import AzureOpenAI
|
from openai import AzureOpenAI
|
||||||
import google.generativeai as genai
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
|
|
||||||
|
|
||||||
def _generate_response(prompt: str) -> str:
|
def _generate_response(prompt: str) -> str:
|
||||||
content = ""
|
content = ""
|
||||||
llm_provider = config.app.get("llm_provider", "openai")
|
llm_provider = config.app.get("llm_provider", "openai")
|
||||||
@@ -29,7 +29,7 @@ def _generate_response(prompt: str) -> str:
|
|||||||
base_url = "https://api.moonshot.cn/v1"
|
base_url = "https://api.moonshot.cn/v1"
|
||||||
elif llm_provider == "ollama":
|
elif llm_provider == "ollama":
|
||||||
# api_key = config.app.get("openai_api_key")
|
# api_key = config.app.get("openai_api_key")
|
||||||
api_key = "ollama" # any string works but you are required to have one
|
api_key = "ollama" # any string works but you are required to have one
|
||||||
model_name = config.app.get("ollama_model_name")
|
model_name = config.app.get("ollama_model_name")
|
||||||
base_url = config.app.get("ollama_base_url", "")
|
base_url = config.app.get("ollama_base_url", "")
|
||||||
if not base_url:
|
if not base_url:
|
||||||
@@ -78,37 +78,38 @@ def _generate_response(prompt: str) -> str:
|
|||||||
return content.replace("\n", "")
|
return content.replace("\n", "")
|
||||||
|
|
||||||
if llm_provider == "gemini":
|
if llm_provider == "gemini":
|
||||||
|
import google.generativeai as genai
|
||||||
genai.configure(api_key=api_key)
|
genai.configure(api_key=api_key)
|
||||||
|
|
||||||
generation_config = {
|
generation_config = {
|
||||||
"temperature": 0.5,
|
"temperature": 0.5,
|
||||||
"top_p": 1,
|
"top_p": 1,
|
||||||
"top_k": 1,
|
"top_k": 1,
|
||||||
"max_output_tokens": 2048,
|
"max_output_tokens": 2048,
|
||||||
}
|
}
|
||||||
|
|
||||||
safety_settings = [
|
safety_settings = [
|
||||||
{
|
{
|
||||||
"category": "HARM_CATEGORY_HARASSMENT",
|
"category": "HARM_CATEGORY_HARASSMENT",
|
||||||
"threshold": "BLOCK_ONLY_HIGH"
|
"threshold": "BLOCK_ONLY_HIGH"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||||
"threshold": "BLOCK_ONLY_HIGH"
|
"threshold": "BLOCK_ONLY_HIGH"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||||
"threshold": "BLOCK_ONLY_HIGH"
|
"threshold": "BLOCK_ONLY_HIGH"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
"threshold": "BLOCK_ONLY_HIGH"
|
"threshold": "BLOCK_ONLY_HIGH"
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
model = genai.GenerativeModel(model_name=model_name,
|
model = genai.GenerativeModel(model_name=model_name,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
safety_settings=safety_settings)
|
safety_settings=safety_settings)
|
||||||
|
|
||||||
convo = model.start_chat(history=[])
|
convo = model.start_chat(history=[])
|
||||||
|
|
||||||
|
|||||||
@@ -143,7 +143,7 @@ def tr(key):
|
|||||||
return loc.get("Translation", {}).get(key, key)
|
return loc.get("Translation", {}).get(key, key)
|
||||||
|
|
||||||
|
|
||||||
with st.expander(tr("Basic Settings"), expanded=True):
|
with st.expander(tr("Basic Settings"), expanded=False):
|
||||||
config_panels = st.columns(3)
|
config_panels = st.columns(3)
|
||||||
left_config_panel = config_panels[0]
|
left_config_panel = config_panels[0]
|
||||||
middle_config_panel = config_panels[1]
|
middle_config_panel = config_panels[1]
|
||||||
|
|||||||
Reference in New Issue
Block a user