init
This commit is contained in:
4
api/__init__.py
Normal file
4
api/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from api.debates import router as debates_router
|
||||
from api.api_keys import router as api_keys_router
|
||||
from api.models import router as models_router
|
||||
from api.setup import router as setup_router
|
||||
113
api/api_keys.py
Normal file
113
api/api_keys.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||
|
||||
from middleware.auth import require_auth
|
||||
import aiohttp
|
||||
|
||||
from services.api_key_service import ApiKeyService
|
||||
from db_models import get_db
|
||||
|
||||
router = APIRouter(tags=["api-keys"])
|
||||
|
||||
|
||||
async def validate_api_key(provider: str, api_key: str):
|
||||
"""
|
||||
Validate an API key by listing models from the provider.
|
||||
All providers validated the same way: if we can list models, the key is valid.
|
||||
"""
|
||||
try:
|
||||
if provider == "openai":
|
||||
import openai
|
||||
async with openai.AsyncOpenAI(api_key=api_key, timeout=10.0) as client:
|
||||
await client.models.list()
|
||||
return True
|
||||
|
||||
# Claude, Qwen, DeepSeek: GET their models endpoint via aiohttp
|
||||
endpoints = {
|
||||
"claude": {
|
||||
"url": "https://api.anthropic.com/v1/models",
|
||||
"headers": {
|
||||
"x-api-key": api_key,
|
||||
"anthropic-version": "2023-06-01"
|
||||
}
|
||||
},
|
||||
"qwen": {
|
||||
"url": "https://dashscope.aliyuncs.com/compatible-mode/v1/models",
|
||||
"headers": {"Authorization": f"Bearer {api_key}"}
|
||||
},
|
||||
"deepseek": {
|
||||
"url": "https://api.deepseek.com/v1/models",
|
||||
"headers": {"Authorization": f"Bearer {api_key}"}
|
||||
}
|
||||
}
|
||||
|
||||
if provider in endpoints:
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
ep = endpoints[provider]
|
||||
async with session.get(ep["url"], headers=ep["headers"]) as response:
|
||||
return response.status == 200
|
||||
|
||||
# Tavily: validate by calling REST API directly (no tavily package needed)
|
||||
if provider == "tavily":
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(
|
||||
"https://api.tavily.com/search",
|
||||
json={"api_key": api_key, "query": "test", "max_results": 1}
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return bool(data.get("results"))
|
||||
return False
|
||||
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
@router.post("/api-keys/{provider}", dependencies=[Depends(require_auth)])
|
||||
async def set_api_key(provider: str, api_key: str = Form(...), db=Depends(get_db)):
|
||||
"""
|
||||
Set API key for a specific provider
|
||||
"""
|
||||
try:
|
||||
success = ApiKeyService.set_api_key(db, provider, api_key)
|
||||
if success:
|
||||
return {"message": f"API key for {provider} updated successfully"}
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Failed to update API key")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/api-keys/{provider}")
|
||||
async def get_api_key(provider: str, db=Depends(get_db)):
|
||||
"""
|
||||
Get API key for a specific provider
|
||||
"""
|
||||
try:
|
||||
api_key = ApiKeyService.get_api_key(db, provider)
|
||||
if api_key:
|
||||
return {"provider": provider, "api_key": api_key}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=f"No API key found for {provider}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/validate-api-key/{provider}", dependencies=[Depends(require_auth)])
|
||||
async def validate_api_key_endpoint(provider: str, api_key: str = Form(...)):
|
||||
"""
|
||||
Validate an API key by making a test request to the provider
|
||||
This endpoint is used by the frontend to validate API keys without CORS issues
|
||||
"""
|
||||
try:
|
||||
is_valid = await validate_api_key(provider, api_key)
|
||||
if is_valid:
|
||||
return {"valid": True, "message": f"Valid {provider} API key"}
|
||||
else:
|
||||
return {"valid": False, "message": f"Invalid {provider} API key"}
|
||||
except Exception as e:
|
||||
return {"valid": False, "message": str(e)}
|
||||
158
api/debates.py
Normal file
158
api/debates.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from middleware.auth import require_auth
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from typing import Dict, Any
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from orchestrator.debate_orchestrator import DebateOrchestrator
|
||||
from models.debate import DebateRequest
|
||||
from storage.session_manager import SessionManager
|
||||
from db_models import get_db
|
||||
|
||||
router = APIRouter(tags=["debates"])
|
||||
|
||||
|
||||
@router.post("/debate/create", dependencies=[Depends(require_auth)])
|
||||
async def create_debate(debate_request: DebateRequest, db=Depends(get_db)) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a new debate session with specified parameters
|
||||
"""
|
||||
try:
|
||||
orchestrator = DebateOrchestrator(db)
|
||||
session_id = await orchestrator.create_session(debate_request)
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"status": "created",
|
||||
"message": f"Debate session {session_id} created successfully"
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/debate/{session_id}")
|
||||
async def get_debate(session_id: str, db=Depends(get_db)) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the current state of a debate session
|
||||
"""
|
||||
try:
|
||||
orchestrator = DebateOrchestrator(db)
|
||||
session = await orchestrator.get_session_status(session_id)
|
||||
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Debate session {session_id} not found")
|
||||
|
||||
return session.dict()
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/debate/{session_id}/start", dependencies=[Depends(require_auth)])
|
||||
async def start_debate(session_id: str, db=Depends(get_db)) -> Dict[str, Any]:
|
||||
"""
|
||||
Start a debate session and stream the results
|
||||
"""
|
||||
try:
|
||||
orchestrator = DebateOrchestrator(db)
|
||||
session = await orchestrator.run_debate(session_id)
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"status": session.status,
|
||||
"message": f"Debate session {session_id} completed"
|
||||
}
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/debate/{session_id}", dependencies=[Depends(require_auth)])
|
||||
async def end_debate(session_id: str, db=Depends(get_db)) -> Dict[str, Any]:
|
||||
"""
|
||||
End a debate session prematurely
|
||||
"""
|
||||
try:
|
||||
orchestrator = DebateOrchestrator(db)
|
||||
await orchestrator.terminate_session(session_id)
|
||||
|
||||
return {"session_id": session_id, "status": "terminated"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/debate/{session_id}/stream")
|
||||
async def stream_debate(session_id: str, db=Depends(get_db)) -> EventSourceResponse:
|
||||
"""
|
||||
Stream debate updates in real-time using Server-Sent Events
|
||||
"""
|
||||
async def event_generator():
|
||||
session_manager = SessionManager()
|
||||
|
||||
# Check if the session exists
|
||||
session = await session_manager.get_session(db, session_id)
|
||||
if not session:
|
||||
yield {"event": "error", "data": json.dumps({"error": f"Session {session_id} not found"})}
|
||||
return
|
||||
|
||||
# Yield initial state
|
||||
yield {"event": "update", "data": json.dumps({
|
||||
"session_id": session_id,
|
||||
"status": session.status,
|
||||
"rounds": [round.dict() for round in session.rounds],
|
||||
"evidence_library": [e.dict() for e in session.evidence_library],
|
||||
"message": "Debate session initialized"
|
||||
}, default=str)}
|
||||
|
||||
last_round_count = len(session.rounds)
|
||||
# Poll until debate completes or times out (max 5 min)
|
||||
for _ in range(150): # 150 × 2s = 300s timeout
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Reset transaction so we see commits from the run_debate request
|
||||
db.commit()
|
||||
|
||||
updated_session = await session_manager.get_session(db, session_id)
|
||||
if not updated_session:
|
||||
break
|
||||
|
||||
# Only yield update when rounds actually changed
|
||||
if len(updated_session.rounds) != last_round_count or updated_session.status != session.status:
|
||||
last_round_count = len(updated_session.rounds)
|
||||
session = updated_session
|
||||
yield {"event": "update", "data": json.dumps({
|
||||
"session_id": session_id,
|
||||
"status": updated_session.status,
|
||||
"rounds": [round.dict() for round in updated_session.rounds],
|
||||
"evidence_library": [e.dict() for e in updated_session.evidence_library],
|
||||
"current_round": len(updated_session.rounds)
|
||||
}, default=str)}
|
||||
|
||||
if updated_session.status in ("completed", "terminated"):
|
||||
yield {"event": "complete", "data": json.dumps({
|
||||
"session_id": session_id,
|
||||
"status": updated_session.status,
|
||||
"summary": updated_session.summary,
|
||||
"rounds": [round.dict() for round in updated_session.rounds],
|
||||
"evidence_library": [e.dict() for e in updated_session.evidence_library]
|
||||
}, default=str)}
|
||||
break
|
||||
|
||||
return EventSourceResponse(event_generator())
|
||||
|
||||
|
||||
@router.get("/sessions")
|
||||
async def list_sessions(db=Depends(get_db)) -> Dict[str, Any]:
|
||||
"""
|
||||
List all debate sessions
|
||||
"""
|
||||
try:
|
||||
session_manager = SessionManager()
|
||||
sessions = await session_manager.list_sessions(db)
|
||||
return {"sessions": sessions}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
173
api/models.py
Normal file
173
api/models.py
Normal file
@@ -0,0 +1,173 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
import aiohttp
|
||||
|
||||
from services.api_key_service import ApiKeyService
|
||||
from db_models import get_db
|
||||
from api.api_keys import validate_api_key
|
||||
|
||||
router = APIRouter(tags=["models"])
|
||||
|
||||
|
||||
@router.get("/models/{provider}")
|
||||
async def get_available_models(provider: str, db=Depends(get_db)):
|
||||
"""
|
||||
Get available models for a specific provider by fetching from their API.
|
||||
Falls back to a curated default list if the API key is missing or the request fails.
|
||||
"""
|
||||
# Curated defaults for each provider — used as fallback
|
||||
default_models = {
|
||||
"openai": [
|
||||
{"model_identifier": "gpt-4o", "display_name": "gpt-4o"},
|
||||
{"model_identifier": "gpt-4o-mini", "display_name": "gpt-4o-mini"},
|
||||
{"model_identifier": "gpt-4-turbo", "display_name": "gpt-4-turbo"},
|
||||
{"model_identifier": "gpt-4", "display_name": "gpt-4"},
|
||||
{"model_identifier": "o3-mini", "display_name": "o3-mini"},
|
||||
{"model_identifier": "gpt-3.5-turbo", "display_name": "gpt-3.5-turbo"}
|
||||
],
|
||||
"claude": [
|
||||
{"model_identifier": "claude-opus-4-5", "display_name": "claude-opus-4-5"},
|
||||
{"model_identifier": "claude-sonnet-4-5", "display_name": "claude-sonnet-4-5"},
|
||||
{"model_identifier": "claude-3-5-sonnet-20241022", "display_name": "claude-3-5-sonnet-20241022"},
|
||||
{"model_identifier": "claude-3-5-haiku-20241022", "display_name": "claude-3-5-haiku-20241022"},
|
||||
{"model_identifier": "claude-3-opus-20240229", "display_name": "claude-3-opus-20240229"}
|
||||
],
|
||||
"qwen": [
|
||||
{"model_identifier": "qwen3-max", "display_name": "qwen3-max"},
|
||||
{"model_identifier": "qwen3-plus", "display_name": "qwen3-plus"},
|
||||
{"model_identifier": "qwen3-flash", "display_name": "qwen3-flash"},
|
||||
{"model_identifier": "qwen-max", "display_name": "qwen-max"},
|
||||
{"model_identifier": "qwen-plus", "display_name": "qwen-plus"},
|
||||
{"model_identifier": "qwen-turbo", "display_name": "qwen-turbo"}
|
||||
],
|
||||
"deepseek": [
|
||||
{"model_identifier": "deepseek-chat", "display_name": "deepseek-chat"},
|
||||
{"model_identifier": "deepseek-reasoner", "display_name": "deepseek-reasoner"},
|
||||
{"model_identifier": "deepseek-v3", "display_name": "deepseek-v3"},
|
||||
{"model_identifier": "deepseek-r1", "display_name": "deepseek-r1"}
|
||||
]
|
||||
}
|
||||
|
||||
defaults = default_models.get(provider, [])
|
||||
|
||||
try:
|
||||
# Retrieve and decrypt API key
|
||||
decrypted_key = ApiKeyService.get_api_key(db, provider)
|
||||
if not decrypted_key:
|
||||
return {"provider": provider, "models": defaults}
|
||||
|
||||
# ---------- OpenAI ----------
|
||||
if provider == "openai":
|
||||
import openai
|
||||
async with openai.AsyncOpenAI(api_key=decrypted_key, timeout=10.0) as client:
|
||||
response = await client.models.list()
|
||||
|
||||
# Keep only chat / reasoning models, sorted newest-first by created timestamp
|
||||
chat_prefixes = ('gpt-', 'o1', 'o3', 'o4', 'chatgpt')
|
||||
models = []
|
||||
seen = set()
|
||||
for m in sorted(response.data, key=lambda x: x.created, reverse=True):
|
||||
if m.id not in seen and m.id.startswith(chat_prefixes):
|
||||
seen.add(m.id)
|
||||
models.append({"model_identifier": m.id, "display_name": m.id})
|
||||
|
||||
if models:
|
||||
return {"provider": provider, "models": models}
|
||||
|
||||
# ---------- Claude ----------
|
||||
elif provider == "claude":
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
headers = {
|
||||
"x-api-key": decrypted_key,
|
||||
"anthropic-version": "2023-06-01"
|
||||
}
|
||||
async with session.get("https://api.anthropic.com/v1/models", headers=headers) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
models = []
|
||||
for m in data.get("data", []):
|
||||
model_id = m.get("id", "")
|
||||
if model_id.startswith("claude-"):
|
||||
models.append({"model_identifier": model_id, "display_name": model_id})
|
||||
if models:
|
||||
return {"provider": provider, "models": models}
|
||||
else:
|
||||
print(f"Claude models API returned {resp.status}: {await resp.text()}")
|
||||
|
||||
# ---------- Qwen ----------
|
||||
elif provider == "qwen":
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
headers = {"Authorization": f"Bearer {decrypted_key}"}
|
||||
async with session.get("https://dashscope.aliyuncs.com/compatible-mode/v1/models", headers=headers) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
exclude_keywords = [
|
||||
'tts', 'vl', 'ocr', 'image', 'asr', '-mt-', '-mt',
|
||||
'math', 'embed', 'rerank', 'coder', 'translate',
|
||||
's2s', 'deep-search', 'omni', 'gui-'
|
||||
]
|
||||
models = []
|
||||
for m in data.get("data", []):
|
||||
model_id = m.get("id", "")
|
||||
if not model_id:
|
||||
continue
|
||||
if not (model_id.startswith("qwen") or model_id.startswith("qwq")):
|
||||
continue
|
||||
if any(kw in model_id for kw in exclude_keywords):
|
||||
continue
|
||||
models.append({"model_identifier": model_id, "display_name": model_id})
|
||||
if models:
|
||||
return {"provider": provider, "models": models}
|
||||
else:
|
||||
print(f"Qwen models API returned {resp.status}: {await resp.text()}")
|
||||
|
||||
# ---------- DeepSeek ----------
|
||||
elif provider == "deepseek":
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
headers = {"Authorization": f"Bearer {decrypted_key}"}
|
||||
async with session.get("https://api.deepseek.com/v1/models", headers=headers) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
models = []
|
||||
for m in data.get("data", []):
|
||||
model_id = m.get("id", "")
|
||||
if model_id.startswith("deepseek"):
|
||||
models.append({"model_identifier": model_id, "display_name": model_id})
|
||||
if models:
|
||||
return {"provider": provider, "models": models}
|
||||
else:
|
||||
print(f"DeepSeek models API returned {resp.status}: {await resp.text()}")
|
||||
|
||||
# API fetch succeeded but returned empty list, or unknown provider — use defaults
|
||||
return {"provider": provider, "models": defaults}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error fetching models for {provider}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return {"provider": provider, "models": defaults}
|
||||
|
||||
|
||||
@router.get("/providers")
|
||||
async def get_available_providers(db=Depends(get_db)):
|
||||
"""
|
||||
Get all providers that have valid API keys set
|
||||
"""
|
||||
try:
|
||||
available_providers = []
|
||||
for provider in ("openai", "claude", "qwen", "deepseek", "tavily"):
|
||||
decrypted_key = ApiKeyService.get_api_key(db, provider)
|
||||
if not decrypted_key:
|
||||
continue
|
||||
is_valid = await validate_api_key(provider, decrypted_key)
|
||||
if is_valid:
|
||||
available_providers.append({
|
||||
"provider": provider,
|
||||
"has_valid_key": True
|
||||
})
|
||||
|
||||
return {"providers": available_providers}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
192
api/setup.py
Normal file
192
api/setup.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.config_service import ConfigService
|
||||
|
||||
router = APIRouter(prefix="/api/setup", tags=["setup"])
|
||||
|
||||
PASSWORD_PLACEHOLDER = "********"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / response models
|
||||
# ---------------------------------------------------------------------------
|
||||
class DatabaseConfig(BaseModel):
|
||||
host: str
|
||||
port: int = 3306
|
||||
user: str = "root"
|
||||
password: str = ""
|
||||
database: str = "dialectica"
|
||||
|
||||
|
||||
class KeycloakConfig(BaseModel):
|
||||
host: str = ""
|
||||
realm: str = ""
|
||||
client_id: str = ""
|
||||
|
||||
|
||||
class TlsConfig(BaseModel):
|
||||
cert_path: str = ""
|
||||
key_path: str = ""
|
||||
force_https: bool = False
|
||||
|
||||
|
||||
class FullConfig(BaseModel):
|
||||
database: Optional[DatabaseConfig] = None
|
||||
keycloak: Optional[KeycloakConfig] = None
|
||||
tls: Optional[TlsConfig] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Access control dependency
|
||||
# ---------------------------------------------------------------------------
|
||||
async def setup_guard(request: Request):
|
||||
"""Three-phase access control for setup routes.
|
||||
|
||||
1. Not initialized → only localhost allowed
|
||||
2. ENV_MODE=dev → open
|
||||
3. ENV_MODE=prod → Keycloak admin JWT required
|
||||
"""
|
||||
config = ConfigService.load()
|
||||
initialized = config.get("initialized", False)
|
||||
env_mode = os.getenv("ENV_MODE", "dev")
|
||||
|
||||
if env_mode == "dev":
|
||||
return # dev mode: no auth needed, even before initialisation
|
||||
|
||||
if not initialized:
|
||||
# prod + not initialised: only localhost may configure
|
||||
client_ip = request.client.host
|
||||
if client_ip not in ("127.0.0.1", "::1"):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="初次设置仅允许从本机访问",
|
||||
)
|
||||
return
|
||||
|
||||
# prod → delegate to Keycloak middleware (Phase 3)
|
||||
from app.middleware.auth import require_admin
|
||||
await require_admin(request, config)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routes
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.get("/status")
|
||||
async def setup_status():
|
||||
"""Return current system initialisation state, including KC info for OIDC."""
|
||||
config = ConfigService.load()
|
||||
env_mode = os.getenv("ENV_MODE", "dev")
|
||||
|
||||
result = {
|
||||
"initialized": config.get("initialized", False),
|
||||
"env_mode": env_mode,
|
||||
"db_configured": ConfigService.is_db_configured(),
|
||||
}
|
||||
|
||||
# Include Keycloak info so the frontend can build OIDC config
|
||||
kc = config.get("keycloak", {})
|
||||
if env_mode == "prod" and kc.get("host") and kc.get("realm"):
|
||||
result["keycloak"] = {
|
||||
"authority": f"{kc['host']}/realms/{kc['realm']}",
|
||||
"client_id": kc.get("client_id", ""),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/config", dependencies=[Depends(setup_guard)])
|
||||
async def get_config():
|
||||
"""Return full config with passwords replaced by placeholder."""
|
||||
config = ConfigService.load()
|
||||
if "database" in config and config["database"].get("password"):
|
||||
config["database"]["password"] = PASSWORD_PLACEHOLDER
|
||||
return config
|
||||
|
||||
|
||||
@router.put("/config", dependencies=[Depends(setup_guard)])
|
||||
async def update_config(payload: FullConfig):
|
||||
"""Merge submitted config sections into the YAML file."""
|
||||
config = ConfigService.load()
|
||||
|
||||
if payload.database is not None:
|
||||
db_dict = payload.database.model_dump()
|
||||
# If password is the placeholder, keep the existing real password
|
||||
if db_dict.get("password") == PASSWORD_PLACEHOLDER:
|
||||
db_dict["password"] = config.get("database", {}).get("password", "")
|
||||
config["database"] = db_dict
|
||||
if payload.keycloak is not None:
|
||||
config["keycloak"] = payload.keycloak.model_dump()
|
||||
if payload.tls is not None:
|
||||
config["tls"] = payload.tls.model_dump()
|
||||
|
||||
ConfigService.save(config)
|
||||
return {"message": "配置已保存"}
|
||||
|
||||
|
||||
@router.post("/test-db", dependencies=[Depends(setup_guard)])
|
||||
async def test_db_connection(db_config: DatabaseConfig):
|
||||
"""Test a database connection with the provided parameters (no save)."""
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
password = db_config.password
|
||||
# If password is the placeholder, use the real password from config
|
||||
if password == PASSWORD_PLACEHOLDER:
|
||||
password = ConfigService.load().get("database", {}).get("password", "")
|
||||
|
||||
url = (
|
||||
f"mysql+pymysql://{quote_plus(db_config.user)}:{quote_plus(password)}"
|
||||
f"@{db_config.host}:{db_config.port}/{db_config.database}"
|
||||
)
|
||||
try:
|
||||
engine = create_engine(url, pool_pre_ping=True)
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("SELECT 1"))
|
||||
engine.dispose()
|
||||
return {"success": True, "message": "数据库连接成功"}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": str(e)}
|
||||
|
||||
|
||||
@router.post("/test-keycloak", dependencies=[Depends(setup_guard)])
|
||||
async def test_keycloak(kc_config: KeycloakConfig):
|
||||
"""Test Keycloak connectivity by fetching the OIDC discovery document."""
|
||||
import httpx
|
||||
|
||||
well_known = (
|
||||
f"{kc_config.host}/realms/{kc_config.realm}"
|
||||
f"/.well-known/openid-configuration"
|
||||
)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
resp = await client.get(well_known)
|
||||
if resp.status_code == 200:
|
||||
return {"success": True, "message": "Keycloak 连通正常"}
|
||||
return {"success": False, "message": f"HTTP {resp.status_code}"}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": str(e)}
|
||||
|
||||
|
||||
@router.post("/initialize", dependencies=[Depends(setup_guard)])
|
||||
async def initialize():
|
||||
"""Mark system as initialised and reload DB connection."""
|
||||
config = ConfigService.load()
|
||||
|
||||
if not ConfigService.is_db_configured():
|
||||
raise HTTPException(status_code=400, detail="请先配置数据库连接")
|
||||
|
||||
# Reload DB engine so business routes can start working
|
||||
from app.db_models import reload_db_connection
|
||||
from app.storage.database import init_db
|
||||
|
||||
reload_db_connection()
|
||||
init_db()
|
||||
|
||||
config["initialized"] = True
|
||||
ConfigService.save(config)
|
||||
|
||||
return {"message": "系统初始化完成", "initialized": True}
|
||||
Reference in New Issue
Block a user