init
This commit is contained in:
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
/.idea/
|
||||
16
Dockerfile
Normal file
16
Dockerfile
Normal file
@@ -0,0 +1,16 @@
|
||||
# Slim Python 3.13 runtime image for the Dialectica backend.
FROM python:3.13-slim

# Fix: use the key=value ENV form; the space-separated form is deprecated.
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1

WORKDIR /app

# Copy requirements first so dependency installation is layer-cached
# independently of application-code changes.
COPY requirements.txt ./requirements.txt

RUN pip install --no-cache-dir -r requirements.txt

COPY . /app/

# NOTE(review): app.py starts uvicorn on port 8020 — confirm whether 5001
# is the intended published port.
EXPOSE 5001

CMD ["python3", "app.py"]
|
||||
4
api/__init__.py
Normal file
4
api/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from api.debates import router as debates_router
|
||||
from api.api_keys import router as api_keys_router
|
||||
from api.models import router as models_router
|
||||
from api.setup import router as setup_router
|
||||
113
api/api_keys.py
Normal file
113
api/api_keys.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||
|
||||
from middleware.auth import require_auth
|
||||
import aiohttp
|
||||
|
||||
from services.api_key_service import ApiKeyService
|
||||
from db_models import get_db
|
||||
|
||||
router = APIRouter(tags=["api-keys"])
|
||||
|
||||
|
||||
async def validate_api_key(provider: str, api_key: str):
    """
    Validate an API key by listing models from the provider.
    All providers validated the same way: if we can list models, the key is valid.

    Returns True when the provider accepts the key, False for invalid keys,
    unknown providers, and any network/SDK failure.
    """
    try:
        if provider == "openai":
            # OpenAI SDK handles auth; a successful models.list() proves the key.
            import openai
            async with openai.AsyncOpenAI(api_key=api_key, timeout=10.0) as client:
                await client.models.list()
            return True

        # Claude, Qwen, DeepSeek: GET their models endpoint via aiohttp
        endpoints = {
            "claude": {
                "url": "https://api.anthropic.com/v1/models",
                "headers": {
                    "x-api-key": api_key,
                    "anthropic-version": "2023-06-01"
                }
            },
            "qwen": {
                "url": "https://dashscope.aliyuncs.com/compatible-mode/v1/models",
                "headers": {"Authorization": f"Bearer {api_key}"}
            },
            "deepseek": {
                "url": "https://api.deepseek.com/v1/models",
                "headers": {"Authorization": f"Bearer {api_key}"}
            }
        }

        if provider in endpoints:
            timeout = aiohttp.ClientTimeout(total=10)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                ep = endpoints[provider]
                # A 200 response on the models listing means the key is valid.
                async with session.get(ep["url"], headers=ep["headers"]) as response:
                    return response.status == 200

        # Tavily: validate by calling REST API directly (no tavily package needed)
        if provider == "tavily":
            timeout = aiohttp.ClientTimeout(total=10)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(
                    "https://api.tavily.com/search",
                    json={"api_key": api_key, "query": "test", "max_results": 1}
                ) as response:
                    if response.status == 200:
                        # A valid key must also produce at least one result.
                        data = await response.json()
                        return bool(data.get("results"))
                    return False

        # Unknown provider name — treat as invalid.
        return False
    except Exception:
        # Any failure (timeout, DNS, SDK error) is reported as an invalid key.
        return False
|
||||
|
||||
|
||||
@router.post("/api-keys/{provider}", dependencies=[Depends(require_auth)])
async def set_api_key(provider: str, api_key: str = Form(...), db=Depends(get_db)):
    """
    Set API key for a specific provider

    Returns a success message on update; raises HTTPException(500) when the
    service reports failure or an unexpected error occurs.
    """
    try:
        success = ApiKeyService.set_api_key(db, provider, api_key)
        if success:
            return {"message": f"API key for {provider} updated successfully"}
        else:
            raise HTTPException(status_code=500, detail="Failed to update API key")
    # Fix: let the HTTPException above propagate with its original detail
    # instead of being caught and re-wrapped by the generic handler
    # (same pattern as get_api_key in this file).
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/api-keys/{provider}")
async def get_api_key(provider: str, db=Depends(get_db)):
    """
    Get API key for a specific provider

    NOTE(review): this route returns the decrypted key and — unlike the POST
    routes in this file — has no require_auth dependency; confirm whether
    exposing it unauthenticated is intentional.
    """
    try:
        api_key = ApiKeyService.get_api_key(db, provider)
        if api_key:
            return {"provider": provider, "api_key": api_key}
        else:
            raise HTTPException(status_code=404, detail=f"No API key found for {provider}")
    except HTTPException:
        # Propagate the 404 untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/validate-api-key/{provider}", dependencies=[Depends(require_auth)])
async def validate_api_key_endpoint(provider: str, api_key: str = Form(...)):
    """
    Validate an API key by making a test request to the provider
    This endpoint is used by the frontend to validate API keys without CORS issues
    """
    try:
        ok = await validate_api_key(provider, api_key)
    except Exception as exc:
        # Report the failure reason instead of surfacing a 500.
        return {"valid": False, "message": str(exc)}
    if ok:
        return {"valid": True, "message": f"Valid {provider} API key"}
    return {"valid": False, "message": f"Invalid {provider} API key"}
|
||||
158
api/debates.py
Normal file
158
api/debates.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from middleware.auth import require_auth
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from typing import Dict, Any
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from orchestrator.debate_orchestrator import DebateOrchestrator
|
||||
from models.debate import DebateRequest
|
||||
from storage.session_manager import SessionManager
|
||||
from db_models import get_db
|
||||
|
||||
router = APIRouter(tags=["debates"])
|
||||
|
||||
|
||||
@router.post("/debate/create", dependencies=[Depends(require_auth)])
async def create_debate(debate_request: DebateRequest, db=Depends(get_db)) -> Dict[str, Any]:
    """
    Create a new debate session with specified parameters
    """
    try:
        # Delegate session creation to the orchestrator bound to this DB session.
        new_id = await DebateOrchestrator(db).create_session(debate_request)
        return {
            "session_id": new_id,
            "status": "created",
            "message": f"Debate session {new_id} created successfully",
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/debate/{session_id}")
async def get_debate(session_id: str, db=Depends(get_db)) -> Dict[str, Any]:
    """
    Get the current state of a debate session
    """
    try:
        found = await DebateOrchestrator(db).get_session_status(session_id)
        if not found:
            raise HTTPException(status_code=404, detail=f"Debate session {session_id} not found")
        return found.dict()
    except HTTPException:
        # Re-raise the 404 as-is.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/debate/{session_id}/start", dependencies=[Depends(require_auth)])
async def start_debate(session_id: str, db=Depends(get_db)) -> Dict[str, Any]:
    """
    Start a debate session and stream the results
    """
    try:
        finished = await DebateOrchestrator(db).run_debate(session_id)
        return {
            "session_id": session_id,
            "status": finished.status,
            "message": f"Debate session {session_id} completed",
        }
    except Exception as exc:
        # Print the full traceback to the server log before surfacing a 500.
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.delete("/debate/{session_id}", dependencies=[Depends(require_auth)])
async def end_debate(session_id: str, db=Depends(get_db)) -> Dict[str, Any]:
    """
    End a debate session prematurely
    """
    try:
        await DebateOrchestrator(db).terminate_session(session_id)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"session_id": session_id, "status": "terminated"}
|
||||
|
||||
|
||||
@router.get("/debate/{session_id}/stream")
async def stream_debate(session_id: str, db=Depends(get_db)) -> EventSourceResponse:
    """
    Stream debate updates in real-time using Server-Sent Events

    Emits "error" when the session is unknown, "update" on each round change,
    and "complete" once the session reaches a terminal status. Polls the DB
    every 2 seconds for up to 5 minutes.
    """
    async def event_generator():
        session_manager = SessionManager()

        # Check if the session exists
        session = await session_manager.get_session(db, session_id)
        if not session:
            yield {"event": "error", "data": json.dumps({"error": f"Session {session_id} not found"})}
            return

        # Yield initial state
        yield {"event": "update", "data": json.dumps({
            "session_id": session_id,
            "status": session.status,
            "rounds": [round.dict() for round in session.rounds],
            "evidence_library": [e.dict() for e in session.evidence_library],
            "message": "Debate session initialized"
        }, default=str)}

        last_round_count = len(session.rounds)
        # Poll until debate completes or times out (max 5 min)
        for _ in range(150):  # 150 × 2s = 300s timeout
            await asyncio.sleep(2)

            # Reset transaction so we see commits from the run_debate request
            db.commit()

            updated_session = await session_manager.get_session(db, session_id)
            if not updated_session:
                # Session disappeared mid-stream; end the event stream quietly.
                break

            # Only yield update when rounds actually changed
            if len(updated_session.rounds) != last_round_count or updated_session.status != session.status:
                last_round_count = len(updated_session.rounds)
                session = updated_session
                yield {"event": "update", "data": json.dumps({
                    "session_id": session_id,
                    "status": updated_session.status,
                    "rounds": [round.dict() for round in updated_session.rounds],
                    "evidence_library": [e.dict() for e in updated_session.evidence_library],
                    "current_round": len(updated_session.rounds)
                }, default=str)}

            # Terminal statuses end the stream with a final "complete" event.
            if updated_session.status in ("completed", "terminated"):
                yield {"event": "complete", "data": json.dumps({
                    "session_id": session_id,
                    "status": updated_session.status,
                    "summary": updated_session.summary,
                    "rounds": [round.dict() for round in updated_session.rounds],
                    "evidence_library": [e.dict() for e in updated_session.evidence_library]
                }, default=str)}
                break

    return EventSourceResponse(event_generator())
|
||||
|
||||
|
||||
@router.get("/sessions")
async def list_sessions(db=Depends(get_db)) -> Dict[str, Any]:
    """
    List all debate sessions
    """
    try:
        all_sessions = await SessionManager().list_sessions(db)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"sessions": all_sessions}
|
||||
173
api/models.py
Normal file
173
api/models.py
Normal file
@@ -0,0 +1,173 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
import aiohttp
|
||||
|
||||
from services.api_key_service import ApiKeyService
|
||||
from db_models import get_db
|
||||
from api.api_keys import validate_api_key
|
||||
|
||||
router = APIRouter(tags=["models"])
|
||||
|
||||
|
||||
@router.get("/models/{provider}")
async def get_available_models(provider: str, db=Depends(get_db)):
    """
    Get available models for a specific provider by fetching from their API.
    Falls back to a curated default list if the API key is missing or the request fails.
    """
    # Curated defaults for each provider — used as fallback
    default_models = {
        "openai": [
            {"model_identifier": "gpt-4o", "display_name": "gpt-4o"},
            {"model_identifier": "gpt-4o-mini", "display_name": "gpt-4o-mini"},
            {"model_identifier": "gpt-4-turbo", "display_name": "gpt-4-turbo"},
            {"model_identifier": "gpt-4", "display_name": "gpt-4"},
            {"model_identifier": "o3-mini", "display_name": "o3-mini"},
            {"model_identifier": "gpt-3.5-turbo", "display_name": "gpt-3.5-turbo"}
        ],
        "claude": [
            {"model_identifier": "claude-opus-4-5", "display_name": "claude-opus-4-5"},
            {"model_identifier": "claude-sonnet-4-5", "display_name": "claude-sonnet-4-5"},
            {"model_identifier": "claude-3-5-sonnet-20241022", "display_name": "claude-3-5-sonnet-20241022"},
            {"model_identifier": "claude-3-5-haiku-20241022", "display_name": "claude-3-5-haiku-20241022"},
            {"model_identifier": "claude-3-opus-20240229", "display_name": "claude-3-opus-20240229"}
        ],
        "qwen": [
            {"model_identifier": "qwen3-max", "display_name": "qwen3-max"},
            {"model_identifier": "qwen3-plus", "display_name": "qwen3-plus"},
            {"model_identifier": "qwen3-flash", "display_name": "qwen3-flash"},
            {"model_identifier": "qwen-max", "display_name": "qwen-max"},
            {"model_identifier": "qwen-plus", "display_name": "qwen-plus"},
            {"model_identifier": "qwen-turbo", "display_name": "qwen-turbo"}
        ],
        "deepseek": [
            {"model_identifier": "deepseek-chat", "display_name": "deepseek-chat"},
            {"model_identifier": "deepseek-reasoner", "display_name": "deepseek-reasoner"},
            {"model_identifier": "deepseek-v3", "display_name": "deepseek-v3"},
            {"model_identifier": "deepseek-r1", "display_name": "deepseek-r1"}
        ]
    }

    defaults = default_models.get(provider, [])

    try:
        # Retrieve and decrypt API key
        decrypted_key = ApiKeyService.get_api_key(db, provider)
        if not decrypted_key:
            # No key stored — serve the curated fallback list.
            return {"provider": provider, "models": defaults}

        # ---------- OpenAI ----------
        if provider == "openai":
            import openai
            async with openai.AsyncOpenAI(api_key=decrypted_key, timeout=10.0) as client:
                response = await client.models.list()

            # Keep only chat / reasoning models, sorted newest-first by created timestamp
            chat_prefixes = ('gpt-', 'o1', 'o3', 'o4', 'chatgpt')
            models = []
            seen = set()
            for m in sorted(response.data, key=lambda x: x.created, reverse=True):
                # `seen` guards against duplicate ids in the listing.
                if m.id not in seen and m.id.startswith(chat_prefixes):
                    seen.add(m.id)
                    models.append({"model_identifier": m.id, "display_name": m.id})

            if models:
                return {"provider": provider, "models": models}

        # ---------- Claude ----------
        elif provider == "claude":
            timeout = aiohttp.ClientTimeout(total=10)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                headers = {
                    "x-api-key": decrypted_key,
                    "anthropic-version": "2023-06-01"
                }
                async with session.get("https://api.anthropic.com/v1/models", headers=headers) as resp:
                    if resp.status == 200:
                        data = await resp.json()
                        models = []
                        for m in data.get("data", []):
                            model_id = m.get("id", "")
                            if model_id.startswith("claude-"):
                                models.append({"model_identifier": model_id, "display_name": model_id})
                        if models:
                            return {"provider": provider, "models": models}
                    else:
                        # Log and fall through to the defaults at the bottom.
                        print(f"Claude models API returned {resp.status}: {await resp.text()}")

        # ---------- Qwen ----------
        elif provider == "qwen":
            timeout = aiohttp.ClientTimeout(total=10)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                headers = {"Authorization": f"Bearer {decrypted_key}"}
                async with session.get("https://dashscope.aliyuncs.com/compatible-mode/v1/models", headers=headers) as resp:
                    if resp.status == 200:
                        data = await resp.json()
                        # Substrings identifying non-chat models (speech, vision,
                        # embedding, etc.) that should be filtered out.
                        exclude_keywords = [
                            'tts', 'vl', 'ocr', 'image', 'asr', '-mt-', '-mt',
                            'math', 'embed', 'rerank', 'coder', 'translate',
                            's2s', 'deep-search', 'omni', 'gui-'
                        ]
                        models = []
                        for m in data.get("data", []):
                            model_id = m.get("id", "")
                            if not model_id:
                                continue
                            # Keep only qwen*/qwq* chat models.
                            if not (model_id.startswith("qwen") or model_id.startswith("qwq")):
                                continue
                            if any(kw in model_id for kw in exclude_keywords):
                                continue
                            models.append({"model_identifier": model_id, "display_name": model_id})
                        if models:
                            return {"provider": provider, "models": models}
                    else:
                        print(f"Qwen models API returned {resp.status}: {await resp.text()}")

        # ---------- DeepSeek ----------
        elif provider == "deepseek":
            timeout = aiohttp.ClientTimeout(total=10)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                headers = {"Authorization": f"Bearer {decrypted_key}"}
                async with session.get("https://api.deepseek.com/v1/models", headers=headers) as resp:
                    if resp.status == 200:
                        data = await resp.json()
                        models = []
                        for m in data.get("data", []):
                            model_id = m.get("id", "")
                            if model_id.startswith("deepseek"):
                                models.append({"model_identifier": model_id, "display_name": model_id})
                        if models:
                            return {"provider": provider, "models": models}
                    else:
                        print(f"DeepSeek models API returned {resp.status}: {await resp.text()}")

        # API fetch succeeded but returned empty list, or unknown provider — use defaults
        return {"provider": provider, "models": defaults}

    except Exception as e:
        # Network/SDK failure: log it and degrade to the curated defaults.
        print(f"Error fetching models for {provider}: {e}")
        import traceback
        traceback.print_exc()
        return {"provider": provider, "models": defaults}
|
||||
|
||||
|
||||
@router.get("/providers")
async def get_available_providers(db=Depends(get_db)):
    """
    Get all providers that have valid API keys set
    """
    try:
        found = []
        for name in ("openai", "claude", "qwen", "deepseek", "tavily"):
            stored_key = ApiKeyService.get_api_key(db, name)
            if not stored_key:
                # No key stored for this provider — skip the validation call.
                continue
            if await validate_api_key(name, stored_key):
                found.append({
                    "provider": name,
                    "has_valid_key": True
                })

        return {"providers": found}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
192
api/setup.py
Normal file
192
api/setup.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.config_service import ConfigService
|
||||
|
||||
router = APIRouter(prefix="/api/setup", tags=["setup"])
|
||||
|
||||
PASSWORD_PLACEHOLDER = "********"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / response models
|
||||
# ---------------------------------------------------------------------------
|
||||
class DatabaseConfig(BaseModel):
    """Database connection parameters submitted by the setup UI."""
    host: str
    port: int = 3306
    user: str = "root"
    password: str = ""
    database: str = "dialectica"


class KeycloakConfig(BaseModel):
    """Keycloak server location and client id used for OIDC."""
    host: str = ""
    realm: str = ""
    client_id: str = ""


class TlsConfig(BaseModel):
    """TLS certificate paths and HTTPS enforcement flag."""
    cert_path: str = ""
    key_path: str = ""
    force_https: bool = False


class FullConfig(BaseModel):
    """Partial config update payload: sections left as None are not touched."""
    database: Optional[DatabaseConfig] = None
    keycloak: Optional[KeycloakConfig] = None
    tls: Optional[TlsConfig] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Access control dependency
|
||||
# ---------------------------------------------------------------------------
|
||||
async def setup_guard(request: Request):
    """Three-phase access control for setup routes.

    1. Not initialized → only localhost allowed
    2. ENV_MODE=dev → open
    3. ENV_MODE=prod → Keycloak admin JWT required
    """
    config = ConfigService.load()
    initialized = config.get("initialized", False)
    env_mode = os.getenv("ENV_MODE", "dev")

    if env_mode == "dev":
        return  # dev mode: no auth needed, even before initialisation

    if not initialized:
        # prod + not initialised: only localhost may configure
        client_ip = request.client.host
        if client_ip not in ("127.0.0.1", "::1"):
            raise HTTPException(
                status_code=403,
                detail="初次设置仅允许从本机访问",
            )
        return

    # prod → delegate to Keycloak middleware (Phase 3)
    # NOTE(review): imported as "app.middleware.auth" while app.py imports
    # "middleware.auth" directly — confirm the package prefix resolves
    # in the deployed layout.
    from app.middleware.auth import require_admin
    await require_admin(request, config)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routes
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.get("/status")
async def setup_status():
    """Return current system initialisation state, including KC info for OIDC."""
    cfg = ConfigService.load()
    mode = os.getenv("ENV_MODE", "dev")

    payload = {
        "initialized": cfg.get("initialized", False),
        "env_mode": mode,
        "db_configured": ConfigService.is_db_configured(),
    }

    # Include Keycloak info so the frontend can build OIDC config
    kc_cfg = cfg.get("keycloak", {})
    if mode == "prod" and kc_cfg.get("host") and kc_cfg.get("realm"):
        payload["keycloak"] = {
            "authority": f"{kc_cfg['host']}/realms/{kc_cfg['realm']}",
            "client_id": kc_cfg.get("client_id", ""),
        }

    return payload
|
||||
|
||||
|
||||
@router.get("/config", dependencies=[Depends(setup_guard)])
async def get_config():
    """Return full config with passwords replaced by placeholder."""
    cfg = ConfigService.load()
    # Never leak the stored DB password to the client — mask it when present.
    if "database" in cfg and cfg["database"].get("password"):
        cfg["database"]["password"] = PASSWORD_PLACEHOLDER
    return cfg
|
||||
|
||||
|
||||
@router.put("/config", dependencies=[Depends(setup_guard)])
async def update_config(payload: FullConfig):
    """Merge submitted config sections into the YAML file."""
    cfg = ConfigService.load()

    if payload.database is not None:
        incoming = payload.database.model_dump()
        # The placeholder means "unchanged" — restore the stored real password.
        if incoming.get("password") == PASSWORD_PLACEHOLDER:
            incoming["password"] = cfg.get("database", {}).get("password", "")
        cfg["database"] = incoming
    if payload.keycloak is not None:
        cfg["keycloak"] = payload.keycloak.model_dump()
    if payload.tls is not None:
        cfg["tls"] = payload.tls.model_dump()

    ConfigService.save(cfg)
    return {"message": "配置已保存"}
|
||||
|
||||
|
||||
@router.post("/test-db", dependencies=[Depends(setup_guard)])
async def test_db_connection(db_config: DatabaseConfig):
    """Test a database connection with the provided parameters (no save).

    Returns {"success": bool, "message": str}; connection failures are
    reported in the message rather than raised.
    """
    from sqlalchemy import create_engine, text

    password = db_config.password
    # If password is the placeholder, use the real password from config
    if password == PASSWORD_PLACEHOLDER:
        password = ConfigService.load().get("database", {}).get("password", "")

    # URL-quote credentials so special characters survive in the DSN.
    url = (
        f"mysql+pymysql://{quote_plus(db_config.user)}:{quote_plus(password)}"
        f"@{db_config.host}:{db_config.port}/{db_config.database}"
    )
    engine = None
    try:
        engine = create_engine(url, pool_pre_ping=True)
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return {"success": True, "message": "数据库连接成功"}
    except Exception as e:
        return {"success": False, "message": str(e)}
    finally:
        # Fix: dispose on every path — previously the engine leaked its
        # connection pool whenever connect() raised.
        if engine is not None:
            engine.dispose()
|
||||
|
||||
|
||||
@router.post("/test-keycloak", dependencies=[Depends(setup_guard)])
async def test_keycloak(kc_config: KeycloakConfig):
    """Test Keycloak connectivity by fetching the OIDC discovery document."""
    import httpx

    well_known = (
        f"{kc_config.host}/realms/{kc_config.realm}"
        f"/.well-known/openid-configuration"
    )
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get(well_known)
    except Exception as e:
        # Network-level failure: surface the error text to the caller.
        return {"success": False, "message": str(e)}
    if resp.status_code == 200:
        return {"success": True, "message": "Keycloak 连通正常"}
    return {"success": False, "message": f"HTTP {resp.status_code}"}
|
||||
|
||||
|
||||
@router.post("/initialize", dependencies=[Depends(setup_guard)])
async def initialize():
    """Mark system as initialised and reload DB connection.

    Raises HTTPException(400) when no database connection is configured yet.
    """
    config = ConfigService.load()

    if not ConfigService.is_db_configured():
        raise HTTPException(status_code=400, detail="请先配置数据库连接")

    # Reload DB engine so business routes can start working
    # NOTE(review): imported with the "app." package prefix while sibling
    # modules import "db_models" / "storage" directly — confirm both
    # resolve in the deployed layout.
    from app.db_models import reload_db_connection
    from app.storage.database import init_db

    reload_db_connection()
    init_db()

    # Persist the initialised flag only after the reload succeeded.
    config["initialized"] = True
    ConfigService.save(config)

    return {"message": "系统初始化完成", "initialized": True}
|
||||
59
app.py
Normal file
59
app.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse
import uvicorn

from storage.database import init_db
from exceptions import ServiceNotConfiguredError
from middleware.config_guard import ConfigGuardMiddleware
from api import debates_router, api_keys_router, models_router, setup_router

# FastAPI application entry point for the Dialectica backend.
app = FastAPI(
    title="Dialectica - Multi-Model Debate Framework",
    description="A framework for structured debates between multiple language models",
    version="0.1.0"
)

# Add CORS middleware
# NOTE(review): wildcard allow_origins combined with allow_credentials=True is
# rejected by browsers for credentialed requests — confirm before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for development
    allow_credentials=True,
    allow_methods=["*"],  # Allow all methods
    allow_headers=["*"],  # Allow all headers
    # Add exposed headers to allow frontend to access response headers
    expose_headers=["Access-Control-Allow-Origin", "Access-Control-Allow-Credentials"]
)

# Config guard: return 503 for business routes when DB not configured
app.add_middleware(ConfigGuardMiddleware)


@app.exception_handler(ServiceNotConfiguredError)
async def not_configured_handler(request, exc):
    # Translate the domain error into a machine-readable 503 payload.
    return JSONResponse(
        status_code=503,
        content={"error_code": "SERVICE_NOT_CONFIGURED", "detail": str(exc)},
    )


@app.on_event("startup")
def startup_event():
    """Initialize database on startup (skipped if not configured)."""
    init_db()


# Register routers
app.include_router(debates_router)
app.include_router(api_keys_router)
app.include_router(models_router)
app.include_router(setup_router)


@app.get("/")
async def root():
    # Simple liveness/welcome endpoint.
    return {"message": "Welcome to Dialectica - Multi-Model Debate Framework"}


if __name__ == "__main__":
    # NOTE(review): serves on port 8020 while the Dockerfile EXPOSEs 5001 —
    # confirm which port is intended.
    uvicorn.run(app, host="0.0.0.0", port=8020)
|
||||
79
db_models/__init__.py
Normal file
79
db_models/__init__.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from sqlalchemy import Column, Integer, String, Boolean, DateTime, Text
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from datetime import datetime
|
||||
|
||||
from exceptions import ServiceNotConfiguredError
|
||||
from services.config_service import ConfigService
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class ApiKey(Base):
    """One encrypted API key per provider (`provider` is unique)."""

    __tablename__ = "api_keys"

    id = Column(Integer, primary_key=True, index=True)
    provider = Column(String(50), unique=True, index=True, nullable=False)
    api_key_encrypted = Column(Text, nullable=False)  # Encrypted API key
    created_at = Column(DateTime, default=datetime.utcnow)
    # Fix: add onupdate so the timestamp actually changes on row updates;
    # previously it was only set at insert time.
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
|
||||
class ModelConfig(Base):
    """Per-provider model entry with an activation flag."""

    __tablename__ = "model_configs"

    id = Column(Integer, primary_key=True, index=True)
    provider = Column(String(50), nullable=False)
    model_name = Column(String(100), nullable=False)
    display_name = Column(String(100))  # Optional display name
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    # Fix: add onupdate so the timestamp actually changes on row updates;
    # previously it was only set at insert time.
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    __table_args__ = {'mysql_charset': 'utf8mb4'}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lazy engine / session factory — only created when first requested
|
||||
# ---------------------------------------------------------------------------
|
||||
_engine = None
|
||||
_SessionLocal = None
|
||||
|
||||
|
||||
def _get_engine():
    """Return the lazily-built SQLAlchemy engine, creating it on first use."""
    from sqlalchemy import create_engine
    global _engine
    if _engine is not None:
        return _engine
    db_url = ConfigService.get_database_url()
    if not db_url:
        raise ServiceNotConfiguredError("数据库未配置")
    _engine = create_engine(db_url)
    return _engine
|
||||
|
||||
|
||||
def _get_session_factory():
    """Return the module-wide session factory, building it lazily."""
    global _SessionLocal
    if _SessionLocal is not None:
        return _SessionLocal
    _SessionLocal = sessionmaker(
        bind=_get_engine(), autocommit=False, autoflush=False
    )
    return _SessionLocal
|
||||
|
||||
|
||||
def reload_db_connection():
    """Dispose current engine and reset so next call rebuilds from config."""
    global _engine, _SessionLocal
    stale = _engine
    # Reset first so any concurrent lookup rebuilds from fresh config.
    _engine = None
    _SessionLocal = None
    if stale is not None:
        stale.dispose()
|
||||
|
||||
|
||||
def get_db():
    """FastAPI dependency that yields a DB session."""
    session = _get_session_factory()()
    try:
        yield session
    finally:
        # Always return the connection to the pool, even on request errors.
        session.close()
|
||||
3
exceptions/__init__.py
Normal file
3
exceptions/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
class ServiceNotConfiguredError(Exception):
    """Raised when a required service (database, etc.) is not configured."""
|
||||
0
middleware/__init__.py
Normal file
0
middleware/__init__.py
Normal file
114
middleware/auth.py
Normal file
114
middleware/auth.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Keycloak JWT authentication middleware."""
|
||||
|
||||
import os
|
||||
from fastapi import HTTPException, Request
|
||||
from jose import jwt, JWTError, jwk
|
||||
from jose.utils import base64url_decode
|
||||
import httpx
|
||||
|
||||
# Cache JWKS per (host, realm) to avoid fetching on every request
|
||||
_jwks_cache: dict[str, dict] = {}
|
||||
|
||||
|
||||
async def _get_jwks(kc_host: str, realm: str) -> dict:
    """Fetch the JWKS document for a Keycloak realm, with per-realm caching.

    Raises HTTPException(502) when the JWKS endpoint returns a non-200
    status.
    """
    cache_key = f"{kc_host}/{realm}"
    if cache_key in _jwks_cache:
        return _jwks_cache[cache_key]

    url = f"{kc_host}/realms/{realm}/protocol/openid-connect/certs"
    async with httpx.AsyncClient(timeout=10) as client:
        resp = await client.get(url)
        if resp.status_code != 200:
            raise HTTPException(
                status_code=502,
                detail=f"无法获取 Keycloak JWKS: HTTP {resp.status_code}",
            )
        data = resp.json()
    # Cache only successful responses.
    _jwks_cache[cache_key] = data
    return data
|
||||
|
||||
|
||||
def _find_rsa_key(jwks: dict, token: str) -> dict | None:
    """Return the JWKS entry whose ``kid`` matches the token header, or None."""
    kid = jwt.get_unverified_header(token).get("kid")
    return next(
        (candidate for candidate in jwks.get("keys", []) if candidate.get("kid") == kid),
        None,
    )
|
||||
|
||||
|
||||
async def verify_token(request: Request, kc_host: str, realm: str) -> dict:
    """Extract and verify the Bearer JWT from the Authorization header.

    Fetches the realm's JWKS (cached) and matches the token's ``kid`` to a
    signing key; on a miss the cache is cleared and re-fetched once to
    tolerate key rotation.

    Returns the decoded payload on success.
    Raises HTTPException(401) on missing/invalid token.
    """
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="缺少 Authorization Bearer token")

    token = auth_header[7:]  # strip the "Bearer " prefix

    jwks = await _get_jwks(kc_host, realm)
    rsa_key = _find_rsa_key(jwks, token)
    if rsa_key is None:
        # Clear cache in case keys rotated, then retry once with fresh JWKS
        _jwks_cache.pop(f"{kc_host}/{realm}", None)
        jwks = await _get_jwks(kc_host, realm)
        rsa_key = _find_rsa_key(jwks, token)
        if rsa_key is None:
            raise HTTPException(status_code=401, detail="无法匹配 JWT 签名密钥")

    try:
        payload = jwt.decode(
            token,
            rsa_key,
            algorithms=["RS256"],
            options={"verify_aud": False},  # Keycloak audience varies by client
        )
        return payload
    except JWTError as e:
        # Chain the original error so the real cause survives in tracebacks
        # (the original raise dropped the context).
        raise HTTPException(status_code=401, detail=f"JWT 验证失败: {e}") from e
|
||||
|
||||
|
||||
async def require_auth(request: Request):
    """Verify Bearer JWT for write endpoints.

    Dev mode: passthrough (no auth required).
    Prod mode: validates JWT via Keycloak JWKS.

    Returns the decoded token payload in prod mode, or None when auth is
    skipped (dev mode / Keycloak not configured).
    """
    if os.getenv("ENV_MODE", "dev") == "dev":
        return None

    # Local import to avoid a circular dependency at module load time.
    # Fixed: the project imports services as a top-level package (see
    # middleware/config_guard.py); "app.services" does not exist and the
    # previous import crashed every prod-mode request.
    from services.config_service import ConfigService

    config = ConfigService.load()
    kc = config.get("keycloak", {})
    if not kc.get("host"):
        return None  # KC not configured – allow access

    return await verify_token(request, kc["host"], kc.get("realm", ""))
|
||||
|
||||
|
||||
async def require_admin(request: Request, config: dict):
    """Verify the request carries a valid Keycloak JWT with admin role.

    Raises HTTPException(401/403) on failure, 503 when Keycloak is unconfigured.
    """
    keycloak_cfg = config.get("keycloak", {})
    host, realm = keycloak_cfg.get("host"), keycloak_cfg.get("realm")

    if not (host and realm):
        raise HTTPException(
            status_code=503,
            detail="Keycloak 未配置,无法进行鉴权",
        )

    payload = await verify_token(request, host, realm)

    granted_roles = payload.get("realm_access", {}).get("roles", [])
    if "admin" not in granted_roles:
        raise HTTPException(status_code=403, detail="需要 admin 角色")

    return payload
|
||||
32
middleware/config_guard.py
Normal file
32
middleware/config_guard.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from services.config_service import ConfigService
|
||||
|
||||
# Paths that are always accessible, even when DB is not configured
|
||||
_ALLOWED_PREFIXES = ("/api/setup", "/docs", "/openapi", "/redoc")
|
||||
|
||||
|
||||
class ConfigGuardMiddleware(BaseHTTPMiddleware):
    """Return 503 for all business routes when the database is not configured."""

    async def dispatch(self, request, call_next):
        path = request.url.path

        # Always exempt: root, CORS preflight (OPTIONS), and the whitelisted
        # prefixes (setup routes plus the interactive API docs).
        exempt = (
            path == "/"
            or request.method == "OPTIONS"
            or path.startswith(_ALLOWED_PREFIXES)
        )

        if not exempt and not ConfigService.is_db_configured():
            return JSONResponse(
                status_code=503,
                content={
                    "error_code": "SERVICE_NOT_CONFIGURED",
                    "detail": "数据库未配置,请先完成系统初始化",
                },
            )

        return await call_next(request)
|
||||
0
models/__init__.py
Normal file
0
models/__init__.py
Normal file
92
models/debate.py
Normal file
92
models/debate.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ModelProvider(str, Enum):
    """Supported LLM backends; str-valued so members serialize as plain strings."""
    OPENAI = "openai"
    DEEPSEEK = "deepseek"
    QWEN = "qwen"
    CLAUDE = "claude"
|
||||
|
||||
|
||||
class DebateStance(str, Enum):
    """Which side of the motion a participant argues."""
    PRO = "pro"
    CON = "con"
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
    """A single web-search hit."""
    title: str
    url: str
    snippet: str
    score: Optional[float] = None  # relevance score from the search backend, if provided
|
||||
|
||||
|
||||
class SearchEvidence(BaseModel):
    """The query and results gathered for one round of the debate."""
    query: str
    results: List[SearchResult]
    mode: str  # "auto", "tool", "both"
|
||||
|
||||
|
||||
class DebateRound(BaseModel):
    """One turn of the debate: a single speaker's argument plus metadata."""
    round_number: int  # 1-based
    speaker: str  # Model identifier
    stance: DebateStance
    content: str
    timestamp: datetime
    token_count: Optional[int] = None  # approximate (whitespace word count)
    search_evidence: Optional[SearchEvidence] = None  # evidence gathered for this turn, if any
|
||||
|
||||
|
||||
class DebateParticipant(BaseModel):
    """One model taking part in the debate."""
    # Allow field names starting with "model_" (pydantic reserves that
    # namespace by default).
    model_config = {"protected_namespaces": ()}

    model_identifier: str
    provider: ModelProvider
    stance: DebateStance
    api_key: Optional[str] = None  # when None, the provider looks the key up in the DB
|
||||
|
||||
|
||||
class DebateConstraints(BaseModel):
    """Tunable limits and rules for a debate run."""
    max_rounds: int = 5
    max_tokens_per_turn: int = 500
    max_total_tokens: Optional[int] = None  # None means no overall budget
    forbid_repetition: bool = True
    must_respond_to_opponent: bool = True
    web_search_enabled: bool = False
    web_search_mode: str = "auto"  # "auto", "tool", "both"
|
||||
|
||||
|
||||
class DebateRequest(BaseModel):
    """Payload for starting a new debate."""
    topic: str
    participants: List[DebateParticipant]
    constraints: DebateConstraints
    # presumably overrides the built-in system prompt — TODO confirm where applied
    custom_system_prompt: Optional[str] = None
|
||||
|
||||
|
||||
class EvidenceReference(BaseModel):
    """Back-reference from an evidence entry to the round that cited it."""
    round_number: int
    speaker: str
    stance: DebateStance
|
||||
|
||||
|
||||
class EvidenceEntry(BaseModel):
    """A search result deduplicated by URL, plus every round that cited it."""
    title: str
    url: str
    snippet: str
    score: Optional[float] = None
    references: List[EvidenceReference]
|
||||
|
||||
|
||||
class DebateSession(BaseModel):
    """Full state of one debate, persisted between rounds."""
    session_id: str
    topic: str
    participants: List[DebateParticipant]
    constraints: DebateConstraints
    rounds: List[DebateRound]
    status: str  # "active", "completed", "terminated"
    created_at: datetime
    completed_at: Optional[datetime] = None
    summary: Optional[str] = None
    # NOTE: pydantic deep-copies field defaults per instance, so the mutable
    # [] literal is safe here (unlike a plain-Python default argument).
    evidence_library: List[EvidenceEntry] = []
|
||||
0
orchestrator/__init__.py
Normal file
0
orchestrator/__init__.py
Normal file
358
orchestrator/debate_orchestrator.py
Normal file
358
orchestrator/debate_orchestrator.py
Normal file
@@ -0,0 +1,358 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.debate import (
|
||||
DebateRequest, DebateSession, DebateRound,
|
||||
DebateStance, DebateParticipant, SearchEvidence, SearchResult,
|
||||
EvidenceEntry, EvidenceReference
|
||||
)
|
||||
from providers.provider_factory import ProviderFactory
|
||||
from storage.session_manager import SessionManager
|
||||
from utils.summarizer import summarize_debate
|
||||
from services.search_service import SearchService
|
||||
from services.api_key_service import ApiKeyService
|
||||
|
||||
|
||||
class DebateOrchestrator:
    """
    Orchestrates the debate between multiple language models.

    Responsibilities: session lifecycle (create / run / terminate),
    round-robin turn-taking between participants, optional web-search
    evidence (automatic queries and/or model tool calls), response cleanup,
    and persisting progress via SessionManager after every round.
    """

    def __init__(self, db: Session):
        # SQLAlchemy session used for persistence and API-key lookups.
        self.db = db
        self.session_manager = SessionManager()
        self.provider_factory = ProviderFactory()

    async def create_session(self, debate_request: DebateRequest) -> str:
        """
        Create a new debate session, persist it, and return its id.
        """
        session_id = str(uuid.uuid4())

        session = DebateSession(
            session_id=session_id,
            topic=debate_request.topic,
            participants=debate_request.participants,
            constraints=debate_request.constraints,
            rounds=[],
            status="active",
            created_at=datetime.now()
        )

        await self.session_manager.save_session(self.db, session)
        return session_id

    def _get_search_service(self) -> Optional[SearchService]:
        """
        Get a SearchService instance if a Tavily API key is stored in the DB;
        returns None when web search is unavailable.
        """
        tavily_key = ApiKeyService.get_api_key(self.db, "tavily")
        if tavily_key:
            return SearchService(api_key=tavily_key)
        return None

    async def run_debate(self, session_id: str) -> DebateSession:
        """
        Run the complete debate process: one turn per round up to
        constraints.max_rounds, then summarize and mark the session completed.

        Raises ValueError if the session id is unknown.
        """
        session = await self.session_manager.get_session(self.db, session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")

        # Initialize providers for each participant
        providers = {}
        for participant in session.participants:
            provider = self.provider_factory.create_provider(
                self.db,
                participant.provider,
                participant.api_key  # This can be None, and the provider will fetch from DB
            )
            providers[participant.model_identifier] = provider

        # Initialize search service if web search is enabled
        search_service = None
        web_search_enabled = session.constraints.web_search_enabled
        web_search_mode = session.constraints.web_search_mode
        if web_search_enabled:
            search_service = self._get_search_service()
            if not search_service:
                # Degrade gracefully: run the debate without evidence.
                print("Warning: Web search enabled but no Tavily API key found. Disabling search.")
                web_search_enabled = False

        # Run the debate rounds
        for round_num in range(session.constraints.max_rounds):
            # Stop early if the session was terminated externally.
            if session.status != "active":
                break

            # Alternate between participants (round-robin by round number)
            current_participant = session.participants[round_num % len(session.participants)]
            provider = providers[current_participant.model_identifier]

            # Perform automatic search if enabled
            search_evidence = None
            if web_search_enabled and web_search_mode in ("auto", "both"):
                search_evidence = self._perform_automatic_search(
                    search_service, session, round_num
                )

            # Prepare context for the current turn (with search results if available)
            context = self._prepare_context(session, current_participant.stance, search_evidence)

            # Determine if we should use tool calling for this round
            use_tool_calling = (
                web_search_enabled
                and web_search_mode in ("tool", "both")
                and provider.supports_tools()
            )

            if use_tool_calling:
                response, tool_evidence = await self._handle_tool_calls(
                    provider, current_participant.model_identifier,
                    context, search_service
                )
                # Merge tool-based evidence with auto evidence
                if tool_evidence:
                    if search_evidence:
                        search_evidence.results.extend(tool_evidence.results)
                        search_evidence.query += f" | {tool_evidence.query}"
                        search_evidence.mode = "both"
                    else:
                        search_evidence = tool_evidence
            else:
                response = await provider.generate_response(
                    model=current_participant.model_identifier,
                    prompt=context
                )

            # Clean the response to remove any echoed prompt/meta text
            response = self._clean_response(response)

            # Create a new round
            round_data = DebateRound(
                round_number=round_num + 1,
                speaker=current_participant.model_identifier,
                stance=current_participant.stance,
                content=response,
                timestamp=datetime.now(),
                token_count=len(response.split()),  # Approximate token count (whitespace word count)
                search_evidence=search_evidence
            )

            session.rounds.append(round_data)

            # Update evidence library with search results from this round
            if round_data.search_evidence:
                self._update_evidence_library(session, round_data)

            # Persist after every round so progress survives a crash.
            await self.session_manager.update_session(self.db, session)

            # Small delay between rounds to simulate realistic interaction
            await asyncio.sleep(1)

        # Generate summary after all rounds are complete
        summary = await summarize_debate(session)
        session.summary = summary
        session.status = "completed"
        session.completed_at = datetime.now()

        await self.session_manager.update_session(self.db, session)
        return session

    def _perform_automatic_search(
        self, search_service: SearchService, session: DebateSession, round_num: int
    ) -> Optional[SearchEvidence]:
        """
        Perform an automatic web search based on the topic and last opponent argument.

        NOTE(review): search_service.search looks like a synchronous call made
        from the async debate loop, so it would block the event loop while the
        HTTP request runs — confirm and consider an async client.
        """
        last_opponent_arg = None
        if session.rounds:
            # The previous round's speaker is always the opponent in a
            # two-sided round-robin.
            last_opponent_arg = session.rounds[-1].content

        query = SearchService.generate_search_query(session.topic, last_opponent_arg)
        results = search_service.search(query, max_results=3)

        if results:
            return SearchEvidence(
                query=query,
                results=results,
                mode="auto"
            )
        return None

    async def _handle_tool_calls(
        self, provider, model: str, context: str, search_service: SearchService
    ) -> tuple:
        """
        Handle tool calling flow: send prompt with tools, execute any tool calls,
        then re-prompt with results. Max 2 tool call iterations.
        Returns (final_response_text, SearchEvidence_or_None).
        """
        tools = [SearchService.get_tool_definition()]
        all_search_results = []
        all_queries = []

        text, tool_calls = await provider.generate_response_with_tools(
            model=model,
            prompt=context,
            tools=tools,
            max_tokens=500
        )

        # If no tool calls, return the text response directly
        if not tool_calls:
            return text, None

        # Process up to 2 rounds of tool calls
        for iteration in range(2):
            if not tool_calls:
                break

            # Execute each tool call
            tool_results_text = []
            for tc in tool_calls:
                if tc["name"] == "web_search":
                    query = tc["arguments"].get("query", "")
                    all_queries.append(query)
                    results = search_service.search(query, max_results=3)
                    all_search_results.extend(results)

                    # Format results for the model
                    evidence = SearchEvidence(query=query, results=results, mode="tool")
                    tool_results_text.append(SearchService.format_results_for_context(evidence))

            # Re-prompt the model with tool results
            augmented_context = context + "\n" + "\n".join(tool_results_text)
            augmented_context += "\n请基于以上搜索结果和辩论历史,给出你的论点。"

            text, tool_calls = await provider.generate_response_with_tools(
                model=model,
                prompt=augmented_context,
                tools=tools,
                max_tokens=500
            )

        # Build combined evidence
        evidence = None
        if all_search_results:
            evidence = SearchEvidence(
                query=" | ".join(all_queries),
                results=all_search_results,
                mode="tool"
            )

        # If we still got no text (model keeps calling tools), fall back
        # to a plain generation without tools.
        if not text:
            text = await provider.generate_response(model=model, prompt=context, max_tokens=500)

        return text, evidence

    def _update_evidence_library(self, session: DebateSession, round_data: DebateRound):
        """
        Merge search results from a round into the session's evidence library, deduplicating by URL.
        """
        ref = EvidenceReference(
            round_number=round_data.round_number,
            speaker=round_data.speaker,
            stance=round_data.stance
        )

        # URL -> index into evidence_library, for O(1) dedup lookups.
        url_index = {entry.url: i for i, entry in enumerate(session.evidence_library)}

        for result in round_data.search_evidence.results:
            if result.url in url_index:
                entry = session.evidence_library[url_index[result.url]]
                # Avoid duplicate references (same round + speaker)
                if not any(
                    r.round_number == ref.round_number and r.speaker == ref.speaker
                    for r in entry.references
                ):
                    entry.references.append(ref)
            else:
                new_entry = EvidenceEntry(
                    title=result.title,
                    url=result.url,
                    snippet=result.snippet,
                    score=result.score,
                    references=[ref]
                )
                session.evidence_library.append(new_entry)
                url_index[result.url] = len(session.evidence_library) - 1

    def _prepare_context(
        self, session: DebateSession, current_stance: DebateStance,
        search_evidence: Optional[SearchEvidence] = None
    ) -> str:
        """
        Prepare the context/prompt for the current model turn: topic, stance,
        rules, full debate history, and optional search evidence.
        """
        # Determine the stance of the current speaker
        if current_stance == DebateStance.PRO:
            position_desc = "正方(支持方)"
            opposing_desc = "反方(反对方)"
        else:
            position_desc = "反方(反对方)"
            opposing_desc = "正方(支持方)"

        # Build the context with previous rounds
        context_parts = [
            f"辩论主题: {session.topic}",
            f"你的立场: {position_desc}",
            "辩论规则:",
            "- 必须回应对方上一轮的核心论点",
            "- 不得重复自己已提出的观点",
            "- 输出长度限制在合理范围内",
            "\n历史辩论记录:"
        ]

        for round_data in session.rounds:
            stance_text = "正方" if round_data.stance == DebateStance.PRO else "反方"
            context_parts.append(f"第{round_data.round_number}轮 - {stance_text}: {round_data.content}")

        # Inject search results if available
        if search_evidence:
            context_parts.append(SearchService.format_results_for_context(search_evidence))

        context_parts.append(f"\n现在轮到你 ({position_desc}) 发言,请基于以上内容进行回应。注意:直接给出你的论点内容,不要重复上述提示词、辩论规则或历史记录。")

        return "\n".join(context_parts)

    def _clean_response(self, response: str) -> str:
        """
        Clean the model response by stripping echoed prompt/meta prefixes
        (round labels, stance labels, "my stance:" style lead-ins).
        """
        import re

        # Remove common prompt echoes and meta prefixes
        patterns_to_remove = [
            r'^第\d+轮\s*[-::]\s*(正方|反方)\s*[::]?\s*', # round prefix, e.g. "第3轮 - 正方:"
            r'^(正方|反方)\s*[((][^))]*[))]\s*[::]?\s*', # stance with parenthetical, e.g. "正方(支持方):"
            r'^(正方|反方)\s*[::]\s*', # bare stance label, e.g. "正方:"
            r'^我的立场\s*[::]\s*', # "my stance:" prefix
            r'^回应\s*[::]\s*', # "response:" prefix
            r'^辩论发言\s*[::]\s*', # "debate statement:" prefix
        ]

        cleaned = response.strip()
        for pattern in patterns_to_remove:
            cleaned = re.sub(pattern, '', cleaned, flags=re.MULTILINE)

        return cleaned.strip()

    async def get_session_status(self, session_id: str) -> Optional[DebateSession]:
        """
        Get the current status of a debate session (None if unknown id).
        """
        return await self.session_manager.get_session(self.db, session_id)

    async def terminate_session(self, session_id: str):
        """
        Terminate a debate session prematurely; run_debate's loop checks the
        status flag and stops before the next round.
        """
        session = await self.session_manager.get_session(self.db, session_id)
        if session:
            session.status = "terminated"
            await self.session_manager.update_session(self.db, session)
|
||||
0
providers/__init__.py
Normal file
0
providers/__init__.py
Normal file
73
providers/base_provider.py
Normal file
73
providers/base_provider.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class LLMProvider(ABC):
    """
    Abstract base class for LLM providers.

    Concrete subclasses wrap one vendor SDK; each is constructed with a DB
    session (used to look up stored API keys) plus an optional explicit key.
    """

    @abstractmethod
    def __init__(self, db: Session, api_key: Optional[str] = None):
        """
        Initialize the provider with a database session and optional API key.
        """
        pass

    @abstractmethod
    async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str:
        """
        Generate a response from the LLM for a single prompt.
        """
        pass

    def supports_tools(self) -> bool:
        """
        Whether this provider supports function/tool calling.
        Override in subclasses that support it.
        """
        return False

    async def generate_response_with_tools(
        self,
        model: str,
        prompt: str,
        tools: List[Dict[str, Any]],
        max_tokens: Optional[int] = None
    ) -> Tuple[str, List[Dict[str, Any]]]:
        """
        Generate a response with tool definitions available.
        Returns (text_content, tool_calls) where tool_calls is a list of
        dicts with keys: name, arguments (dict).
        If no tools are called, tool_calls is empty and text_content has the response.
        Default implementation falls back to regular generation (tools ignored).
        """
        text = await self.generate_response(model, prompt, max_tokens)
        return text, []
|
||||
|
||||
|
||||
class ProviderFactory:
    """Factory that instantiates the concrete LLM provider for a backend.

    Providers are imported lazily inside each branch so importing this module
    does not require every vendor SDK to be installed.

    NOTE(review): this duplicates ProviderFactory in
    providers/provider_factory.py — consider keeping only one copy.
    """

    @staticmethod
    def create_provider(db: "Session", provider_type: str, api_key: Optional[str] = None):
        """Create a provider instance based on the type.

        Args:
            db: SQLAlchemy session used by providers to look up stored API keys.
            provider_type: a plain string ("openai", "claude", "qwen",
                "deepseek") or a str-valued enum such as ModelProvider.
            api_key: explicit key; when None the provider reads it from the DB.

        Raises:
            ValueError: if the provider type is not supported.
        """
        # Normalize: enums expose .value, plain strings are used as-is.
        # (The original called provider_type.value unconditionally, which
        # raised AttributeError for a bare str despite the str annotation.)
        name = getattr(provider_type, "value", provider_type)
        # Fixed import paths: the project packages are top-level "providers",
        # not "app.providers" (see providers/provider_factory.py).
        if name == "openai":
            from providers.openai_provider import OpenAIProvider
            return OpenAIProvider(db, api_key)
        elif name == "claude":
            from providers.claude_provider import ClaudeProvider
            return ClaudeProvider(db, api_key)
        elif name == "qwen":
            from providers.qwen_provider import QwenProvider
            return QwenProvider(db, api_key)
        elif name == "deepseek":
            from providers.deepseek_provider import DeepSeekProvider
            return DeepSeekProvider(db, api_key)
        else:
            raise ValueError(f"Unsupported provider type: {provider_type}")
|
||||
85
providers/claude_provider.py
Normal file
85
providers/claude_provider.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import anthropic
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from providers.base_provider import LLMProvider
|
||||
from services.api_key_service import ApiKeyService
|
||||
|
||||
SYSTEM_PROMPT = "你正在参与一场结构化辩论。请按照用户消息中的规则进行辩论,直接给出你的论点,不要重复提示词或历史记录。"
|
||||
|
||||
|
||||
class ClaudeProvider(LLMProvider):
    """
    Anthropic Claude API provider implementation.
    """

    def __init__(self, db: Session, api_key: Optional[str] = None):
        # Fall back to the key stored in the database when none is given.
        if not api_key:
            api_key = ApiKeyService.get_api_key(db, "claude")

        if api_key:
            self.client = anthropic.AsyncAnthropic(api_key=api_key)
        else:
            raise ValueError("Claude API key not found in database or provided")

    def supports_tools(self) -> bool:
        # Claude supports native tool use.
        return True

    async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str:
        """Generate one debate turn; wraps SDK errors in a generic Exception."""
        try:
            response = await self.client.messages.create(
                model=model,
                max_tokens=max_tokens or 500,
                temperature=0.7,
                system=SYSTEM_PROMPT,
                messages=[
                    {"role": "user", "content": prompt}
                ]
            )
            # For a plain completion the first content block holds the text.
            return response.content[0].text
        except Exception as e:
            raise Exception(f"Error calling Claude API: {str(e)}")

    async def generate_response_with_tools(
        self,
        model: str,
        prompt: str,
        tools: List[Dict[str, Any]],
        max_tokens: Optional[int] = None
    ) -> Tuple[str, List[Dict[str, Any]]]:
        """Generate with tool definitions available; returns (text, tool_calls)."""
        try:
            # Convert OpenAI-format tools to Anthropic format
            # (name/description/input_schema instead of a "function" wrapper).
            anthropic_tools = []
            for tool in tools:
                func = tool.get("function", tool)
                anthropic_tools.append({
                    "name": func["name"],
                    "description": func.get("description", ""),
                    "input_schema": func.get("parameters", func.get("input_schema", {}))
                })

            response = await self.client.messages.create(
                model=model,
                max_tokens=max_tokens or 500,
                temperature=0.7,
                system=SYSTEM_PROMPT,
                messages=[
                    {"role": "user", "content": prompt}
                ],
                tools=anthropic_tools
            )

            # Claude interleaves text and tool_use blocks in the reply.
            text_content = ""
            tool_calls = []
            for block in response.content:
                if block.type == "text":
                    text_content += block.text
                elif block.type == "tool_use":
                    tool_calls.append({
                        "name": block.name,
                        "arguments": block.input
                    })

            return text_content.strip(), tool_calls
        except Exception as e:
            raise Exception(f"Error calling Claude API with tools: {str(e)}")
|
||||
64
providers/deepseek_provider.py
Normal file
64
providers/deepseek_provider.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from providers.base_provider import LLMProvider
|
||||
from services.api_key_service import ApiKeyService
|
||||
|
||||
|
||||
class DeepSeekProvider(LLMProvider):
    """
    DeepSeek API provider implementation using OpenAI-compatible API.
    """

    def __init__(self, db: Session, api_key: Optional[str] = None):
        # Fall back to the key stored in the database when none is given.
        if not api_key:
            api_key = ApiKeyService.get_api_key(db, "deepseek")

        if not api_key:
            raise ValueError("DeepSeek API key not found in database or provided")

        self.client = AsyncOpenAI(
            api_key=api_key,
            base_url="https://api.deepseek.com"
        )

    def supports_tools(self) -> bool:
        # Tool calling is not wired up for DeepSeek in this codebase.
        return False

    async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str:
        """Generate one debate turn, with special handling for reasoner models."""
        try:
            # Heuristic: treat "deepseek-reasoner" / R1-style models specially.
            is_reasoner = "reasoner" in model or "r1" in model.lower()

            messages = [{"role": "user", "content": prompt}]
            # deepseek-reasoner does not support a system message, so the
            # debate instructions stay inside the user message for it.
            if not is_reasoner:
                messages.insert(0, {
                    "role": "system",
                    "content": "你正在参与一场结构化辩论。请按照用户消息中的规则进行辩论,直接给出你的论点,不要重复提示词或历史记录。"
                })

            kwargs = {
                "model": model,
                "messages": messages,
            }
            # Reasoner models do not accept max_tokens; use
            # max_completion_tokens (with a larger budget for reasoning).
            if is_reasoner:
                kwargs["max_completion_tokens"] = max_tokens or 4096
            else:
                kwargs["max_tokens"] = max_tokens or 500

            response = await self.client.chat.completions.create(**kwargs)

            message = response.choices[0].message
            content = message.content or ""

            # deepseek-reasoner may put its main output in reasoning_content
            # when content comes back empty.
            if not content.strip() and is_reasoner:
                reasoning = getattr(message, "reasoning_content", None)
                if reasoning:
                    content = reasoning

            return content.strip()
        except Exception as e:
            raise Exception(f"Error calling DeepSeek API: {str(e)}")
|
||||
72
providers/openai_provider.py
Normal file
72
providers/openai_provider.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
import openai
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from providers.base_provider import LLMProvider
|
||||
from services.api_key_service import ApiKeyService
|
||||
|
||||
SYSTEM_PROMPT = "你正在参与一场结构化辩论。请按照用户消息中的规则进行辩论,直接给出你的论点,不要重复提示词或历史记录。"
|
||||
|
||||
|
||||
class OpenAIProvider(LLMProvider):
    """
    OpenAI API provider implementation.
    """

    def __init__(self, db: Session, api_key: Optional[str] = None):
        # Fall back to the key stored in the database when none is given.
        if not api_key:
            api_key = ApiKeyService.get_api_key(db, "openai")

        if api_key:
            self.client = openai.AsyncOpenAI(api_key=api_key)
        else:
            raise ValueError("OpenAI API key not found in database or provided")

    def supports_tools(self) -> bool:
        # OpenAI chat completions support function/tool calling.
        return True

    async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str:
        """Generate one debate turn; wraps SDK errors in a generic Exception."""
        try:
            response = await self.client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=max_tokens or 500
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            raise Exception(f"Error calling OpenAI API: {str(e)}")

    async def generate_response_with_tools(
        self,
        model: str,
        prompt: str,
        tools: List[Dict[str, Any]],
        max_tokens: Optional[int] = None
    ) -> Tuple[str, List[Dict[str, Any]]]:
        """Generate with tool definitions available; returns (text, tool_calls)."""
        try:
            response = await self.client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt}
                ],
                tools=tools,
                tool_choice="auto",  # let the model decide whether to call a tool
                max_tokens=max_tokens or 500
            )
            message = response.choices[0].message
            text_content = message.content or ""
            tool_calls = []
            if message.tool_calls:
                for tc in message.tool_calls:
                    tool_calls.append({
                        "name": tc.function.name,
                        # Arguments arrive as a JSON string; decode to a dict.
                        "arguments": json.loads(tc.function.arguments)
                    })
            return text_content.strip(), tool_calls
        except Exception as e:
            raise Exception(f"Error calling OpenAI API with tools: {str(e)}")
|
||||
49
providers/provider_factory.py
Normal file
49
providers/provider_factory.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class LLMProvider(ABC):
    """
    Abstract base class for LLM providers.

    NOTE(review): this duplicates LLMProvider in providers/base_provider.py
    (which the concrete providers actually subclass); consider removing one
    of the two copies.
    """

    @abstractmethod
    def __init__(self, db: Session, api_key: Optional[str] = None):
        """
        Initialize the provider with a database session and optional API key.
        """
        pass

    @abstractmethod
    async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str:
        """
        Generate a response from the LLM for a single prompt.
        """
        pass
|
||||
|
||||
|
||||
class ProviderFactory:
    """Factory that instantiates the concrete LLM provider for a backend.

    Providers are imported lazily inside each branch so importing this module
    does not require every vendor SDK to be installed.
    """

    @staticmethod
    def create_provider(db: "Session", provider_type: str, api_key: Optional[str] = None):
        """Create a provider instance based on the type.

        Args:
            db: SQLAlchemy session used by providers to look up stored API keys.
            provider_type: a plain string ("openai", "claude", "qwen",
                "deepseek") or a str-valued enum such as ModelProvider.
            api_key: explicit key; when None the provider reads it from the DB.

        Raises:
            ValueError: if the provider type is not supported.
        """
        # Normalize: enums expose .value, plain strings are used as-is.
        # (The original called provider_type.value unconditionally, which
        # raised AttributeError for a bare str despite the str annotation.)
        name = getattr(provider_type, "value", provider_type)
        if name == "openai":
            from providers.openai_provider import OpenAIProvider
            return OpenAIProvider(db, api_key)
        elif name == "claude":
            from providers.claude_provider import ClaudeProvider
            return ClaudeProvider(db, api_key)
        elif name == "qwen":
            from providers.qwen_provider import QwenProvider
            return QwenProvider(db, api_key)
        elif name == "deepseek":
            from providers.deepseek_provider import DeepSeekProvider
            return DeepSeekProvider(db, api_key)
        else:
            raise ValueError(f"Unsupported provider type: {provider_type}")
|
||||
75
providers/qwen_provider.py
Normal file
75
providers/qwen_provider.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import json
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from providers.base_provider import LLMProvider
|
||||
from services.api_key_service import ApiKeyService
|
||||
|
||||
# System prompt instructing the model to debate per the user-message rules
# without echoing the prompt or history (kept in Chinese: it is sent verbatim
# to the model).
SYSTEM_PROMPT = "你正在参与一场结构化辩论。请按照用户消息中的规则进行辩论,直接给出你的论点,不要重复提示词或历史记录。"


class QwenProvider(LLMProvider):
    """
    Qwen API provider implementation using DashScope's OpenAI-compatible API.
    """

    def __init__(self, db: Session, api_key: Optional[str] = None):
        """
        Resolve the API key (explicit argument first, then the encrypted
        store) and build an async client against the DashScope endpoint.

        Raises:
            ValueError: when no Qwen API key can be resolved.
        """
        if not api_key:
            api_key = ApiKeyService.get_api_key(db, "qwen")

        if not api_key:
            raise ValueError("Qwen API key not found in database or provided")

        self.client = AsyncOpenAI(
            api_key=api_key,
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
        )

    def supports_tools(self) -> bool:
        """This provider supports OpenAI-style function calling (tools)."""
        return True

    async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str:
        """
        Generate a single plain-text completion for *prompt*.

        Args:
            model: Qwen model identifier.
            prompt: user prompt for this turn.
            max_tokens: completion cap; defaults to 500 when not given.

        Returns:
            The stripped completion text.

        Raises:
            Exception: wrapping any underlying API error (cause chained).
        """
        try:
            response = await self.client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=max_tokens or 500
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            # Chain the original exception so the root cause is not lost.
            raise Exception(f"Error calling Qwen API: {str(e)}") from e

    async def generate_response_with_tools(
        self,
        model: str,
        prompt: str,
        tools: List[Dict[str, Any]],
        max_tokens: Optional[int] = None
    ) -> Tuple[str, List[Dict[str, Any]]]:
        """
        Generate a completion with function-calling enabled.

        Args:
            model: Qwen model identifier.
            prompt: user prompt for this turn.
            tools: OpenAI-format tool definitions.
            max_tokens: completion cap; defaults to 500 when not given.

        Returns:
            A tuple of (stripped text content, list of tool calls), where each
            tool call is a dict with "name" and parsed "arguments".

        Raises:
            Exception: wrapping any underlying API error (cause chained).
        """
        try:
            response = await self.client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt}
                ],
                tools=tools,
                tool_choice="auto",
                max_tokens=max_tokens or 500
            )
            message = response.choices[0].message
            # content may be None when the model responds only with tool calls.
            text_content = message.content or ""
            tool_calls = []
            if message.tool_calls:
                for tc in message.tool_calls:
                    tool_calls.append({
                        "name": tc.function.name,
                        # arguments arrive as a JSON string; parse to a dict.
                        "arguments": json.loads(tc.function.arguments)
                    })
            return text_content.strip(), tool_calls
        except Exception as e:
            raise Exception(f"Error calling Qwen API with tools: {str(e)}") from e
|
||||
19
requirements.txt
Normal file
19
requirements.txt
Normal file
@@ -0,0 +1,19 @@
|
||||
fastapi==0.115.0
|
||||
uvicorn[standard]==0.32.0
|
||||
pydantic==2.9.2
|
||||
pydantic-settings==2.6.1
|
||||
sqlalchemy==2.0.35
|
||||
aiosqlite==0.20.0
|
||||
pymysql==1.1.1
|
||||
cryptography==43.0.1
|
||||
python-multipart==0.0.20
|
||||
sse-starlette==2.1.3
|
||||
openai==1.52.2
|
||||
anthropic==0.40.0
|
||||
tiktoken==0.8.0
|
||||
python-dotenv==1.0.1
|
||||
aiohttp==3.9.0
|
||||
httpx==0.27.2
|
||||
tavily-python==0.5.0
|
||||
pyyaml==6.0.2
|
||||
python-jose[cryptography]==3.3.0
|
||||
0
services/__init__.py
Normal file
0
services/__init__.py
Normal file
71
services/api_key_service.py
Normal file
71
services/api_key_service.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from sqlalchemy.orm import Session
|
||||
from db_models import ApiKey
|
||||
import os
|
||||
|
||||
# Initialize the encryption key from environment or generate a new one.
# NOTE(review): if ENCRYPTION_KEY is not set, a fresh key is generated at
# import time, so API keys encrypted by a previous process can no longer be
# decrypted (get_api_key will hit InvalidToken and return None). Confirm
# ENCRYPTION_KEY is always provided in deployment.
ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", Fernet.generate_key().decode())
cipher_suite = Fernet(ENCRYPTION_KEY.encode())
|
||||
|
||||
|
||||
class ApiKeyService:
    """
    Service for managing provider API keys, stored Fernet-encrypted in the
    database (see the module-level ``cipher_suite``).
    """

    @staticmethod
    def encrypt_api_key(api_key: str) -> str:
        """
        Encrypt an API key with the module-level cipher.

        Returns:
            The Fernet token as a str.
        """
        encrypted_key = cipher_suite.encrypt(api_key.encode())
        return encrypted_key.decode()

    @staticmethod
    def decrypt_api_key(encrypted_api_key: str) -> str:
        """
        Decrypt an API key.

        Raises:
            InvalidToken: when the ciphertext does not match the current key.
        """
        decrypted_key = cipher_suite.decrypt(encrypted_api_key.encode())
        return decrypted_key.decode()

    @staticmethod
    def get_api_key(db: Session, provider: str) -> str | None:
        """
        Retrieve and decrypt the API key for *provider*.

        Returns:
            The plaintext key, or None when no record exists or the stored
            ciphertext cannot be decrypted with the current ENCRYPTION_KEY.
        """
        api_key_record = db.query(ApiKey).filter(ApiKey.provider == provider).first()
        if not api_key_record or not api_key_record.api_key_encrypted:
            return None
        try:
            return ApiKeyService.decrypt_api_key(api_key_record.api_key_encrypted)
        except InvalidToken:
            # Stored with a different ENCRYPTION_KEY — treat as missing.
            return None

    @staticmethod
    def set_api_key(db: Session, provider: str, api_key: str) -> bool:
        """
        Encrypt and upsert the API key for *provider*.

        Returns:
            True on successful commit, False (after rollback) on any DB error.
        """
        encrypted_key = ApiKeyService.encrypt_api_key(api_key)

        # Upsert: update the existing row if present, otherwise insert.
        api_key_record = db.query(ApiKey).filter(ApiKey.provider == provider).first()

        if api_key_record:
            api_key_record.api_key_encrypted = encrypted_key
        else:
            api_key_record = ApiKey(
                provider=provider,
                api_key_encrypted=encrypted_key
            )
            db.add(api_key_record)

        try:
            db.commit()
            return True
        except Exception:
            db.rollback()
            return False
|
||||
100
services/config_service.py
Normal file
100
services/config_service.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import yaml
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
|
||||
# Location of the YAML config file; overridable via CONFIG_PATH for tests.
CONFIG_PATH = Path(os.getenv("CONFIG_PATH", "/app/config/dialectica.yaml"))


# Reuse the same encryption key used for API keys. When the key is absent,
# _cipher is None and secret fields are stored/read as plaintext.
_ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", "")
_cipher = Fernet(_ENCRYPTION_KEY.encode()) if _ENCRYPTION_KEY else None


# Names of YAML fields that should be encrypted at rest (at any nesting depth).
_SECRET_FIELDS = {"password"}
|
||||
|
||||
|
||||
def _encrypt(value: str) -> str:
    """Encrypt *value* and prefix it with "ENC:"; return it unchanged when no
    cipher is configured or the value is empty."""
    if _cipher and value:
        token = _cipher.encrypt(value.encode()).decode()
        return f"ENC:{token}"
    return value
|
||||
|
||||
|
||||
def _decrypt(value: str) -> str:
    """Reverse _encrypt for "ENC:"-prefixed strings; anything else passes
    through unchanged (including tokens that fail to decrypt)."""
    looks_encrypted = (
        _cipher is not None
        and isinstance(value, str)
        and value.startswith("ENC:")
    )
    if not looks_encrypted:
        return value
    try:
        return _cipher.decrypt(value[4:].encode()).decode()
    except InvalidToken:
        # Wrong key or corrupted token — hand back the raw value.
        return value
|
||||
|
||||
|
||||
def _encrypt_secrets(data: dict) -> dict:
    """Return a copy of *data* with secret string fields encrypted, recursing
    into nested dicts. Already-encrypted ("ENC:"-prefixed) values are kept."""
    result = {}
    for key, val in data.items():
        if isinstance(val, dict):
            result[key] = _encrypt_secrets(val)
            continue
        needs_encryption = (
            key in _SECRET_FIELDS
            and isinstance(val, str)
            and not val.startswith("ENC:")
        )
        result[key] = _encrypt(val) if needs_encryption else val
    return result
|
||||
|
||||
|
||||
def _decrypt_secrets(data: dict) -> dict:
    """Inverse of _encrypt_secrets: return a copy of *data* with secret string
    fields decrypted, recursing into nested dicts."""

    def _convert(key, val):
        # Nested sections are handled recursively; secret strings decrypted.
        if isinstance(val, dict):
            return _decrypt_secrets(val)
        if key in _SECRET_FIELDS and isinstance(val, str):
            return _decrypt(val)
        return val

    return {k: _convert(k, v) for k, v in data.items()}
|
||||
|
||||
|
||||
class ConfigService:
    """Read / write config/dialectica.yaml, with secret fields encrypted at
    rest (see _encrypt_secrets / _decrypt_secrets)."""

    @staticmethod
    def load() -> dict:
        """Load config, returning decrypted values. Empty dict if file missing."""
        if not CONFIG_PATH.exists():
            return {}
        with open(CONFIG_PATH) as f:
            raw = yaml.safe_load(f) or {}
        return _decrypt_secrets(raw)

    @staticmethod
    def save(config: dict):
        """Save config, encrypting secret fields before writing to disk."""
        CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
        encrypted = _encrypt_secrets(config)
        with open(CONFIG_PATH, "w") as f:
            yaml.dump(encrypted, f, default_flow_style=False, allow_unicode=True)

    @staticmethod
    def _db_section() -> dict | None:
        """Return the database config section when it is usable (has both
        host and database), else None. Shared by the two public checks below
        so the predicate is defined in exactly one place."""
        db = ConfigService.load().get("database", {})
        if db.get("host") and db.get("database"):
            return db
        return None

    @staticmethod
    def is_db_configured() -> bool:
        """True when a usable database section (host + database) exists."""
        return ConfigService._db_section() is not None

    @staticmethod
    def get_database_url() -> str | None:
        """Build a SQLAlchemy MySQL URL from config, or None if unconfigured."""
        db = ConfigService._db_section()
        if db is None:
            return None
        user = db.get("user", "root")
        password = db.get("password", "")
        host = db["host"]
        port = db.get("port", 3306)
        database = db["database"]
        # user/password may contain URL-reserved characters — escape them.
        return f"mysql+pymysql://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{database}"

    @staticmethod
    def is_initialized() -> bool:
        """True once setup has written initialized: true to the config file."""
        return ConfigService.load().get("initialized", False)
|
||||
104
services/search_service.py
Normal file
104
services/search_service.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from typing import List, Optional
|
||||
from models.debate import SearchResult, SearchEvidence
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""
|
||||
Tavily web search wrapper for debate research.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
from tavily import TavilyClient
|
||||
self.client = TavilyClient(api_key=api_key)
|
||||
|
||||
def search(self, query: str, max_results: int = 5) -> List[SearchResult]:
|
||||
"""
|
||||
Perform a web search using Tavily and return structured results.
|
||||
"""
|
||||
try:
|
||||
response = self.client.search(
|
||||
query=query,
|
||||
max_results=max_results,
|
||||
search_depth="basic"
|
||||
)
|
||||
results = []
|
||||
for item in response.get("results", []):
|
||||
results.append(SearchResult(
|
||||
title=item.get("title", ""),
|
||||
url=item.get("url", ""),
|
||||
snippet=item.get("content", "")[:500], # Truncate to avoid token bloat
|
||||
score=item.get("score")
|
||||
))
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"Tavily search error: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def format_results_for_context(evidence: SearchEvidence) -> str:
|
||||
"""
|
||||
Format search results into a string suitable for injecting into the LLM context.
|
||||
"""
|
||||
if not evidence or not evidence.results:
|
||||
return ""
|
||||
lines = [f"\n[网络搜索结果] 搜索词: \"{evidence.query}\""]
|
||||
for i, r in enumerate(evidence.results, 1):
|
||||
lines.append(f" {i}. {r.title}")
|
||||
lines.append(f" {r.snippet}")
|
||||
lines.append(f" 来源: {r.url}")
|
||||
lines.append("[搜索结果结束]\n")
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def generate_search_query(topic: str, last_opponent_argument: Optional[str] = None) -> str:
|
||||
"""
|
||||
Generate a search query from the debate topic and the opponent's last argument.
|
||||
"""
|
||||
if last_opponent_argument:
|
||||
# Extract key phrases from the opponent's argument (first ~100 chars)
|
||||
snippet = last_opponent_argument[:100].strip()
|
||||
return f"{topic} {snippet}"
|
||||
return topic
|
||||
|
||||
@staticmethod
|
||||
def get_tool_definition() -> dict:
|
||||
"""
|
||||
Return the web_search tool definition for function calling.
|
||||
"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web for current information relevant to the debate topic. Use this to find facts, statistics, or recent news that support your argument.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query to look up"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_tool_definition_anthropic() -> dict:
|
||||
"""
|
||||
Return the web_search tool definition in Anthropic format.
|
||||
"""
|
||||
return {
|
||||
"name": "web_search",
|
||||
"description": "Search the web for current information relevant to the debate topic. Use this to find facts, statistics, or recent news that support your argument.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query to look up"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
0
storage/__init__.py
Normal file
0
storage/__init__.py
Normal file
109
storage/database.py
Normal file
109
storage/database.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Text
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from models.debate import DebateSession
|
||||
from exceptions import ServiceNotConfiguredError
|
||||
from services.config_service import ConfigService
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class DebateSessionDB(Base):
    """SQLAlchemy row model for a debate session. Nested structures
    (participants, constraints, rounds, evidence) are serialized to JSON text
    columns by debate_session_to_db / debate_session_from_db below."""
    __tablename__ = "debate_sessions"

    id = Column(Integer, primary_key=True, index=True)
    # Public handle used by the API to look sessions up.
    session_id = Column(String(255), unique=True, index=True, nullable=False)
    topic = Column(Text, nullable=False)
    participants = Column(Text, nullable=False)  # JSON string
    constraints = Column(Text, nullable=False)  # JSON string
    rounds = Column(Text, nullable=False)  # JSON string
    status = Column(String(50), nullable=False)
    created_at = Column(DateTime, nullable=False)
    completed_at = Column(DateTime, nullable=True)  # null while running
    summary = Column(Text, nullable=True)
    evidence_library = Column(Text, nullable=True)  # JSON string, optional
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Lazy engine / session factory
# ---------------------------------------------------------------------------
# Both are created on first use (see _get_engine / _get_session_factory) so
# this module can be imported before the database has been configured.
_engine = None
_SessionLocal = None
|
||||
|
||||
|
||||
def _get_engine():
    """Return the process-wide SQLAlchemy engine, creating it on first call.

    Raises:
        ServiceNotConfiguredError: when no database URL is configured yet.
    """
    from sqlalchemy import create_engine

    global _engine
    if _engine is not None:
        return _engine
    url = ConfigService.get_database_url()
    if not url:
        raise ServiceNotConfiguredError("数据库未配置")
    _engine = create_engine(url)
    return _engine
|
||||
|
||||
|
||||
def _get_session_factory():
    """Return the cached sessionmaker bound to the shared engine, building it
    lazily on first call."""
    global _SessionLocal
    if _SessionLocal is not None:
        return _SessionLocal
    _SessionLocal = sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=_get_engine(),
    )
    return _SessionLocal
|
||||
|
||||
|
||||
def init_db():
    """Create all tables if DB is configured; silently skip otherwise.

    Best-effort: a connection failure is logged, not raised, so the app can
    still start and be configured later.
    """
    if not ConfigService.is_db_configured():
        print("WARNING: Database not configured, skipping table creation.")
        return
    try:
        # Two declarative bases share the same engine: this module's Base and
        # the API-key models' Base from db_models.
        from db_models import Base as ApiBase
        engine = _get_engine()
        Base.metadata.create_all(bind=engine)
        ApiBase.metadata.create_all(bind=engine)
    except Exception as e:
        global _engine, _SessionLocal
        print(f"WARNING: Database connection failed, skipping table creation: {e}")
        # Reset the cached engine/session factory so a later call can retry
        # against fresh configuration instead of reusing a broken engine.
        if _engine is not None:
            _engine.dispose()
        _engine = None
        _SessionLocal = None
|
||||
|
||||
|
||||
def debate_session_from_db(db_session) -> DebateSession:
    """Convert a DebateSessionDB row into the DebateSession Pydantic model,
    parsing the JSON-text columns back into structured data."""
    raw_evidence = db_session.evidence_library
    evidence = json.loads(raw_evidence) if raw_evidence else []

    return DebateSession(
        session_id=db_session.session_id,
        topic=db_session.topic,
        participants=json.loads(db_session.participants),
        constraints=json.loads(db_session.constraints),
        rounds=json.loads(db_session.rounds),
        status=db_session.status,
        created_at=db_session.created_at,
        completed_at=db_session.completed_at,
        summary=db_session.summary,
        evidence_library=evidence,
    )
|
||||
|
||||
|
||||
def debate_session_to_db(session: DebateSession) -> DebateSessionDB:
    """Convert the DebateSession Pydantic model into a DebateSessionDB row,
    flattening nested models to JSON text columns.

    NOTE(review): ``.dict()`` is the Pydantic v1-style API (deprecated in v2
    in favor of ``model_dump()``) — confirm before migrating. ``default=str``
    stringifies values json cannot serialize natively.
    """
    participants_json = json.dumps([p.dict() for p in session.participants])
    constraints_json = json.dumps(session.constraints.dict())
    rounds_json = json.dumps([r.dict() for r in session.rounds], default=str)
    if session.evidence_library:
        evidence_json = json.dumps(
            [e.dict() for e in session.evidence_library], default=str
        )
    else:
        evidence_json = None

    return DebateSessionDB(
        session_id=session.session_id,
        topic=session.topic,
        participants=participants_json,
        constraints=constraints_json,
        rounds=rounds_json,
        status=session.status,
        created_at=session.created_at,
        completed_at=session.completed_at,
        summary=session.summary,
        evidence_library=evidence_json,
    )
|
||||
67
storage/session_manager.py
Normal file
67
storage/session_manager.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from models.debate import DebateSession
|
||||
from storage.database import DebateSessionDB, debate_session_from_db, debate_session_to_db
|
||||
|
||||
|
||||
class SessionManager:
    """
    Manages debate sessions in storage: CRUD over DebateSessionDB rows via a
    caller-provided SQLAlchemy session.
    """

    def __init__(self):
        # Stateless; all methods take the DB session explicitly.
        pass

    async def save_session(self, db: Session, session: DebateSession):
        """
        Save a new debate session row to the database.
        """
        db_session = debate_session_to_db(session)
        db.add(db_session)
        db.commit()
        db.refresh(db_session)

    async def get_session(self, db: Session, session_id: str) -> Optional[DebateSession]:
        """
        Retrieve one debate session by its public session_id, or None when it
        does not exist.
        """
        db_session = db.query(DebateSessionDB).filter(DebateSessionDB.session_id == session_id).first()
        if not db_session:
            return None
        return debate_session_from_db(db_session)

    async def update_session(self, db: Session, session: DebateSession):
        """
        Overwrite the stored row matching session.session_id with the current
        state. Silently a no-op when no such row exists.
        """
        db_session = db.query(DebateSessionDB).filter(DebateSessionDB.session_id == session.session_id).first()
        if db_session:
            # Re-serialize all mutable fields; default=str covers values json
            # cannot serialize natively (e.g. timestamps inside rounds).
            db_session.topic = session.topic
            db_session.participants = json.dumps([p.dict() for p in session.participants])
            db_session.constraints = json.dumps(session.constraints.dict())
            db_session.rounds = json.dumps([r.dict() for r in session.rounds], default=str)
            db_session.status = session.status
            db_session.completed_at = session.completed_at
            db_session.summary = session.summary
            db_session.evidence_library = json.dumps([e.dict() for e in session.evidence_library], default=str) if session.evidence_library else None

            db.commit()
            db.refresh(db_session)

    async def list_sessions(self, db: Session):
        """
        List lightweight summaries (session_id, topic, status, created_at) of
        all stored debate sessions.
        """
        db_sessions = db.query(DebateSessionDB).all()
        return [
            {
                "session_id": db_session.session_id,
                "topic": db_session.topic,
                "status": db_session.status,
                "created_at": db_session.created_at.isoformat() if db_session.created_at else None
            }
            for db_session in db_sessions
        ]
|
||||
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
39
utils/summarizer.py
Normal file
39
utils/summarizer.py
Normal file
@@ -0,0 +1,39 @@
|
||||
async def summarize_debate(session):
    """
    Generate a plain-text summary of a debate session.

    Groups round contents by stance (``stance.value == "pro"`` vs everything
    else), lists each side's points truncated to 100 characters, and closes
    with a line stating the round count.

    Args:
        session: object with ``topic`` and ``rounds``; each round has
            ``stance.value`` and ``content``.

    Returns:
        The summary string, or a fallback message when no rounds exist.
    """
    if not session.rounds:
        return "No rounds were completed in this debate."

    # Partition arguments by stance; any non-"pro" stance counts as con.
    pro_points = []
    con_points = []
    for round_data in session.rounds:
        if round_data.stance.value == "pro":
            pro_points.append(round_data.content)
        else:
            con_points.append(round_data.content)

    summary_parts = [
        f"辩论主题: {session.topic}",
        "",
        "正方主要观点:",
    ]
    # Truncate each point for brevity ("..." is appended unconditionally).
    summary_parts.extend(f"{i}. {point[:100]}..." for i, point in enumerate(pro_points, 1))

    summary_parts.append("")
    summary_parts.append("反方主要观点:")
    summary_parts.extend(f"{i}. {point[:100]}..." for i, point in enumerate(con_points, 1))

    summary_parts.append("")
    # f-string rather than str.format, matching the rest of this function.
    summary_parts.append(
        f"总结: 本次辩论完成了 {len(session.rounds)} 轮,双方就 '{session.topic}' 主题进行了充分的讨论。"
    )

    return "\n".join(summary_parts)
|
||||
Reference in New Issue
Block a user