114 lines
4.1 KiB
Python
114 lines
4.1 KiB
Python
from fastapi import APIRouter, Depends, Form, HTTPException
|
|
|
|
from middleware.auth import require_auth
|
|
import aiohttp
|
|
|
|
from services.api_key_service import ApiKeyService
|
|
from db_models import get_db
|
|
|
|
router = APIRouter(tags=["api-keys"])
|
|
|
|
|
|
async def validate_api_key(provider: str, api_key: str) -> bool:
    """
    Validate an API key by making a cheap authenticated request to the provider.

    - ``openai``: lists models through the ``openai`` async client.
    - ``claude``, ``qwen``, ``deepseek``: GET the provider's models endpoint
      via aiohttp; the key is considered valid only on HTTP 200.
    - ``tavily``: POSTs a one-result search; the key is considered valid only
      on HTTP 200 with a non-empty ``results`` field.

    Args:
        provider: Provider identifier (one of the names listed above).
        api_key: The key to check.

    Returns:
        True if the provider accepted the key, False otherwise. Unknown
        providers and any error raised while checking (network failure,
        timeout, missing ``openai`` package, ...) also yield False.
    """
    # Local import, mirroring the lazy ``import openai`` below.
    import logging

    request_timeout_seconds = 10

    try:
        if provider == "openai":
            import openai

            async with openai.AsyncOpenAI(
                api_key=api_key, timeout=float(request_timeout_seconds)
            ) as client:
                # An invalid key makes this call raise, which lands in the
                # except clause below.
                await client.models.list()
                return True

        # Claude, Qwen, DeepSeek: GET their models endpoint via aiohttp
        endpoints = {
            "claude": {
                "url": "https://api.anthropic.com/v1/models",
                "headers": {
                    "x-api-key": api_key,
                    "anthropic-version": "2023-06-01",
                },
            },
            "qwen": {
                "url": "https://dashscope.aliyuncs.com/compatible-mode/v1/models",
                "headers": {"Authorization": f"Bearer {api_key}"},
            },
            "deepseek": {
                "url": "https://api.deepseek.com/v1/models",
                "headers": {"Authorization": f"Bearer {api_key}"},
            },
        }

        if provider in endpoints:
            ep = endpoints[provider]
            timeout = aiohttp.ClientTimeout(total=request_timeout_seconds)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(ep["url"], headers=ep["headers"]) as response:
                    return response.status == 200

        # Tavily: validate by calling REST API directly (no tavily package needed)
        if provider == "tavily":
            timeout = aiohttp.ClientTimeout(total=request_timeout_seconds)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(
                    "https://api.tavily.com/search",
                    json={"api_key": api_key, "query": "test", "max_results": 1},
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        return bool(data.get("results"))
                    return False

        # Unknown provider: nothing to validate against.
        return False
    except Exception as exc:
        # Best-effort check: every failure is reported as "not valid", but
        # leave a trace so outages or a missing dependency can be told apart
        # from a bad key. Only the exception class is logged; the message is
        # omitted because provider errors may echo parts of the key.
        logging.getLogger(__name__).warning(
            "API key validation for provider %r failed with %s",
            provider,
            type(exc).__name__,
        )
        return False
|
|
|
|
|
|
@router.post("/api-keys/{provider}", dependencies=[Depends(require_auth)])
async def set_api_key(provider: str, api_key: str = Form(...), db=Depends(get_db)):
    """
    Set API key for a specific provider.

    Delegates to ``ApiKeyService.set_api_key(db, provider, api_key)`` and
    reports the outcome.

    Args:
        provider: Provider name taken from the URL path.
        api_key: The key to store, submitted as a form field.
        db: Database handle injected through ``get_db``.

    Returns:
        A dict whose ``message`` confirms the update.

    Raises:
        HTTPException: 500 when the service returns a falsy result or raises
            an unexpected exception; an HTTPException raised by the service
            itself is propagated unchanged.
    """
    try:
        success = ApiKeyService.set_api_key(db, provider, api_key)
    except HTTPException:
        # Pass HTTP errors through untouched (same policy as get_api_key).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Raised outside the try block so the generic handler above cannot
    # re-wrap it and mangle the detail message.
    if not success:
        raise HTTPException(status_code=500, detail="Failed to update API key")

    return {"message": f"API key for {provider} updated successfully"}
|
|
|
|
|
|
# Authentication is required here as well: this route returns a stored secret,
# and the sibling routes in this router are already guarded by require_auth.
@router.get("/api-keys/{provider}", dependencies=[Depends(require_auth)])
async def get_api_key(provider: str, db=Depends(get_db)):
    """
    Get API key for a specific provider.

    Looks the key up through ``ApiKeyService.get_api_key(db, provider)``.

    Args:
        provider: Provider name taken from the URL path.
        db: Database handle injected through ``get_db``.

    Returns:
        A dict with the ``provider`` name and its stored ``api_key``.

    Raises:
        HTTPException: 404 when the service returns a falsy value for the
            provider; 500 when the lookup raises an unexpected exception.
    """
    try:
        api_key = ApiKeyService.get_api_key(db, provider)
    except HTTPException:
        # Preserve the status code of HTTP errors raised further down.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not api_key:
        raise HTTPException(status_code=404, detail=f"No API key found for {provider}")

    return {"provider": provider, "api_key": api_key}
|
|
|
|
|
|
@router.post("/validate-api-key/{provider}", dependencies=[Depends(require_auth)])
async def validate_api_key_endpoint(provider: str, api_key: str = Form(...)):
    """
    Check an API key server-side by sending a test request to the provider.

    The frontend calls this endpoint so that the provider request is made
    from the backend and is not subject to browser CORS restrictions.

    Args:
        provider: Provider name taken from the URL path.
        api_key: The key to check, submitted as a form field.

    Returns:
        A dict with a boolean ``valid`` flag and a human-readable ``message``.
        If the check itself raises, ``valid`` is False and ``message`` is the
        text of the exception.
    """
    try:
        key_accepted = await validate_api_key(provider, api_key)
    except Exception as exc:
        return {"valid": False, "message": str(exc)}

    if not key_accepted:
        return {"valid": False, "message": f"Invalid {provider} API key"}
    return {"valid": True, "message": f"Valid {provider} API key"}
|