support baidu ERNIE llm

This commit is contained in:
harry
2024-07-03 21:12:21 +08:00
parent dca23d99e4
commit 66c81a04bf
5 changed files with 64 additions and 6 deletions

View File

@@ -174,6 +174,10 @@ if not config.app.get("hide_config", False):
st.session_state['ui_language'] = code
config.ui['language'] = code
# 是否禁用日志显示
hide_log = st.checkbox(tr("Hide Log"), value=config.app.get("hide_log", False))
config.ui['hide_log'] = hide_log
with middle_config_panel:
# openai
# moonshot (月之暗面)
@@ -184,7 +188,7 @@ if not config.app.get("hide_config", False):
# gemini
# ollama
llm_providers = ['OpenAI', 'Moonshot', 'Azure', 'Qwen', 'DeepSeek', 'Gemini', 'Ollama', 'G4f', 'OneAPI',
"Cloudflare"]
"Cloudflare", "ERNIE"]
saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
saved_llm_provider_index = 0
for i, provider in enumerate(llm_providers):
@@ -198,6 +202,7 @@ if not config.app.get("hide_config", False):
config.app["llm_provider"] = llm_provider
llm_api_key = config.app.get(f"{llm_provider}_api_key", "")
llm_secret_key = config.app.get(f"{llm_provider}_secret_key", "") # only for baidu ernie
llm_base_url = config.app.get(f"{llm_provider}_base_url", "")
llm_model_name = config.app.get(f"{llm_provider}_model_name", "")
llm_account_id = config.app.get(f"{llm_provider}_account_id", "")
@@ -300,6 +305,15 @@ if not config.app.get("hide_config", False):
- **Model Name**: 固定为 deepseek-chat
"""
if llm_provider == 'ernie':
with llm_helper:
tips = """
##### 百度文心一言 配置说明
- **API Key**: [点击到官网申请](https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application)
- **Secret Key**: [点击到官网申请](https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application)
- **Base Url**: 填写 **请求地址** [点击查看文档](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11#%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E)
"""
if tips and config.ui['language'] == 'zh':
st.warning(
"中国用户建议使用 **DeepSeek** 或 **Moonshot** 作为大模型提供商\n- 国内可直接访问不需要VPN \n- 注册就送额度,基本够用")
@@ -307,7 +321,9 @@ if not config.app.get("hide_config", False):
st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password")
st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url)
st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name)
st_llm_model_name = ""
if llm_provider != 'ernie':
st.text_input(tr("Model Name"), value=llm_model_name)
if st_llm_api_key:
config.app[f"{llm_provider}_api_key"] = st_llm_api_key
@@ -315,6 +331,9 @@ if not config.app.get("hide_config", False):
config.app[f"{llm_provider}_base_url"] = st_llm_base_url
if st_llm_model_name:
config.app[f"{llm_provider}_model_name"] = st_llm_model_name
if llm_provider == 'ernie':
st_llm_secret_key = st.text_input(tr("Secret Key"), value=llm_secret_key, type="password")
config.app[f"{llm_provider}_secret_key"] = st_llm_secret_key
if llm_provider == 'cloudflare':
st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id)
@@ -622,6 +641,8 @@ if start_button:
def log_received(msg):
if config.ui['hide_log']:
return
with log_container:
log_records.append(msg)
st.code("\n".join(log_records))