174 lines
8.2 KiB
Python
174 lines
8.2 KiB
Python
import asyncio
import logging

import aiohttp
from fastapi import APIRouter, Depends, HTTPException

from services.api_key_service import ApiKeyService
from db_models import get_db
from api.api_keys import validate_api_key
|
|
|
|
# Router for the model-discovery endpoints in this module; FastAPI groups
# every route registered on it under the "models" tag in the OpenAPI schema.
router = APIRouter(
    tags=["models"],
)
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Curated fallback model IDs per provider. Served when no API key is stored,
# when the provider's model listing fails, or when it yields no usable models.
_DEFAULT_MODEL_IDS = {
    "openai": (
        "gpt-4o",
        "gpt-4o-mini",
        "gpt-4-turbo",
        "gpt-4",
        "o3-mini",
        "gpt-3.5-turbo",
    ),
    "claude": (
        "claude-opus-4-5",
        "claude-sonnet-4-5",
        "claude-3-5-sonnet-20241022",
        "claude-3-5-haiku-20241022",
        "claude-3-opus-20240229",
    ),
    "qwen": (
        "qwen3-max",
        "qwen3-plus",
        "qwen3-flash",
        "qwen-max",
        "qwen-plus",
        "qwen-turbo",
    ),
    "deepseek": (
        "deepseek-chat",
        "deepseek-reasoner",
        "deepseek-v3",
        "deepseek-r1",
    ),
}

# Total timeout for the plain-HTTP model listings (Claude, Qwen, DeepSeek).
_MODELS_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=10)

# OpenAI IDs with these prefixes are treated as chat / reasoning models...
_OPENAI_CHAT_PREFIXES = ("gpt-", "o1", "o3", "o4", "chatgpt")
# ...except IDs containing these, which share a chat prefix but belong to
# non-chat endpoints (realtime, speech, transcription, image, legacy completions).
_OPENAI_NON_CHAT_KEYWORDS = ("realtime", "tts", "transcribe", "image", "instruct")

_ANTHROPIC_MODELS_URL = "https://api.anthropic.com/v1/models"
# Anthropic's list endpoint pages at 20 by default; 1000 is its maximum page size.
_ANTHROPIC_PAGE_LIMIT = 1000

_QWEN_MODELS_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1/models"
# Substrings marking Qwen model IDs to skip (specialised variants such as
# vision, speech, OCR, embedding, coder). "-mt" also matches "-mt-".
_QWEN_EXCLUDE_KEYWORDS = (
    "tts", "vl", "ocr", "image", "asr", "-mt",
    "math", "embed", "rerank", "coder", "translate",
    "s2s", "deep-search", "omni", "gui-",
)

_DEEPSEEK_MODELS_URL = "https://api.deepseek.com/v1/models"


def _as_model_entries(model_ids):
    """Wrap model IDs in the response shape, using the ID as the display name."""
    return [{"model_identifier": mid, "display_name": mid} for mid in model_ids]


async def _fetch_model_ids(url, headers, label, params=None):
    """GET an OpenAI-style ``/models`` listing and return its ``id`` strings.

    Args:
        url: The provider's models endpoint.
        headers: Auth headers for the request.
        label: Provider name used in the log message on a non-200 reply.
        params: Optional query parameters.

    Returns:
        The ``id`` of every entry in the response's ``data`` array (missing or
        null IDs become ``""``), or ``[]`` if the status is not 200.

    Raises:
        Whatever aiohttp raises on network, timeout or JSON-decoding errors.
    """
    async with aiohttp.ClientSession(timeout=_MODELS_HTTP_TIMEOUT) as session:
        async with session.get(url, headers=headers, params=params) as resp:
            if resp.status != 200:
                logger.warning(
                    "%s models API returned %s: %s",
                    label, resp.status, await resp.text(),
                )
                return []
            data = await resp.json()
    return [entry.get("id") or "" for entry in data.get("data", [])]


async def _list_provider_model_ids(provider, api_key):
    """Return the usable model IDs reported by *provider*'s API.

    Returns ``[]`` for an unknown provider or when the listing is empty or
    rejected. Network and client errors propagate to the caller.
    """
    if provider == "openai":
        import openai  # kept as a function-local import, as in the original

        async with openai.AsyncOpenAI(api_key=api_key, timeout=10.0) as client:
            response = await client.models.list()
        newest_first = sorted(response.data, key=lambda m: m.created, reverse=True)
        # dict.fromkeys de-duplicates while preserving newest-first order.
        return list(dict.fromkeys(
            m.id for m in newest_first
            if m.id.startswith(_OPENAI_CHAT_PREFIXES)
            and not any(kw in m.id for kw in _OPENAI_NON_CHAT_KEYWORDS)
        ))

    if provider == "claude":
        ids = await _fetch_model_ids(
            _ANTHROPIC_MODELS_URL,
            {"x-api-key": api_key, "anthropic-version": "2023-06-01"},
            "Claude",
            params={"limit": _ANTHROPIC_PAGE_LIMIT},
        )
        return [mid for mid in ids if mid.startswith("claude-")]

    if provider == "qwen":
        ids = await _fetch_model_ids(
            _QWEN_MODELS_URL, {"Authorization": f"Bearer {api_key}"}, "Qwen"
        )
        return [
            mid for mid in ids
            if mid.startswith(("qwen", "qwq"))
            and not any(kw in mid for kw in _QWEN_EXCLUDE_KEYWORDS)
        ]

    if provider == "deepseek":
        ids = await _fetch_model_ids(
            _DEEPSEEK_MODELS_URL, {"Authorization": f"Bearer {api_key}"}, "DeepSeek"
        )
        return [mid for mid in ids if mid.startswith("deepseek")]

    return []


@router.get("/models/{provider}")
async def get_available_models(provider: str, db=Depends(get_db)):
    """
    Get available models for a specific provider by fetching from their API.

    Falls back to a curated default list if the API key is missing, the
    request fails, or the provider returns no usable models. Unknown
    providers get an empty list.

    Returns:
        ``{"provider": provider, "models": [{"model_identifier", "display_name"}, ...]}``
    """
    defaults = _as_model_entries(_DEFAULT_MODEL_IDS.get(provider, ()))

    try:
        # NOTE(review): if get_api_key does blocking DB I/O, it runs on the
        # event loop here; consider offloading it — confirm.
        decrypted_key = ApiKeyService.get_api_key(db, provider)
        if not decrypted_key:
            return {"provider": provider, "models": defaults}
        model_ids = await _list_provider_model_ids(provider, decrypted_key)
    except Exception:
        # Best-effort endpoint: any failure degrades to the curated list.
        logger.exception("Error fetching models for %s", provider)
        return {"provider": provider, "models": defaults}

    models = _as_model_entries(model_ids)
    return {"provider": provider, "models": models or defaults}
|
|
|
|
|
|
# Providers whose stored API keys are checked by /providers, in response order.
_KEYED_PROVIDERS = ("openai", "claude", "qwen", "deepseek", "tavily")


@router.get("/providers")
async def get_available_providers(db=Depends(get_db)):
    """
    Get all providers that have valid API keys set.

    Providers with no stored key are skipped. The remaining keys are checked
    with ``validate_api_key``; these checks run concurrently.

    Returns:
        ``{"providers": [{"provider": name, "has_valid_key": True}, ...]}``,
        ordered as in ``_KEYED_PROVIDERS``.

    Raises:
        HTTPException: 500 if key lookup or validation raises.
    """
    try:
        keyed = []
        for provider in _KEYED_PROVIDERS:
            decrypted_key = ApiKeyService.get_api_key(db, provider)
            if decrypted_key:
                keyed.append((provider, decrypted_key))

        # The checks are independent of each other (presumably one network
        # round-trip each), so await them together; gather keeps input order.
        verdicts = await asyncio.gather(
            *(validate_api_key(provider, key) for provider, key in keyed)
        )
    except Exception as e:
        logging.getLogger(__name__).exception("Failed to list available providers")
        # NOTE(review): str(e) exposes internal error text to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "providers": [
            {"provider": provider, "has_valid_key": True}
            for (provider, _), is_valid in zip(keyed, verdicts)
            if is_valid
        ]
    }
|