init
This commit is contained in:
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
/.idea/
|
||||||
16
Dockerfile
Normal file
16
Dockerfile
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
FROM python:3.13-slim
|
||||||
|
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE 1
|
||||||
|
ENV PYTHONUNBUFFERED 1
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY requirements.txt ./requirements.txt
|
||||||
|
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
COPY . /app/
|
||||||
|
|
||||||
|
EXPOSE 5001
|
||||||
|
|
||||||
|
CMD ["python3", "app.py"]
|
||||||
4
api/__init__.py
Normal file
4
api/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from api.debates import router as debates_router
|
||||||
|
from api.api_keys import router as api_keys_router
|
||||||
|
from api.models import router as models_router
|
||||||
|
from api.setup import router as setup_router
|
||||||
113
api/api_keys.py
Normal file
113
api/api_keys.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||||
|
|
||||||
|
from middleware.auth import require_auth
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from services.api_key_service import ApiKeyService
|
||||||
|
from db_models import get_db
|
||||||
|
|
||||||
|
router = APIRouter(tags=["api-keys"])
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_api_key(provider: str, api_key: str):
|
||||||
|
"""
|
||||||
|
Validate an API key by listing models from the provider.
|
||||||
|
All providers validated the same way: if we can list models, the key is valid.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if provider == "openai":
|
||||||
|
import openai
|
||||||
|
async with openai.AsyncOpenAI(api_key=api_key, timeout=10.0) as client:
|
||||||
|
await client.models.list()
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Claude, Qwen, DeepSeek: GET their models endpoint via aiohttp
|
||||||
|
endpoints = {
|
||||||
|
"claude": {
|
||||||
|
"url": "https://api.anthropic.com/v1/models",
|
||||||
|
"headers": {
|
||||||
|
"x-api-key": api_key,
|
||||||
|
"anthropic-version": "2023-06-01"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"qwen": {
|
||||||
|
"url": "https://dashscope.aliyuncs.com/compatible-mode/v1/models",
|
||||||
|
"headers": {"Authorization": f"Bearer {api_key}"}
|
||||||
|
},
|
||||||
|
"deepseek": {
|
||||||
|
"url": "https://api.deepseek.com/v1/models",
|
||||||
|
"headers": {"Authorization": f"Bearer {api_key}"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider in endpoints:
|
||||||
|
timeout = aiohttp.ClientTimeout(total=10)
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
ep = endpoints[provider]
|
||||||
|
async with session.get(ep["url"], headers=ep["headers"]) as response:
|
||||||
|
return response.status == 200
|
||||||
|
|
||||||
|
# Tavily: validate by calling REST API directly (no tavily package needed)
|
||||||
|
if provider == "tavily":
|
||||||
|
timeout = aiohttp.ClientTimeout(total=10)
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
async with session.post(
|
||||||
|
"https://api.tavily.com/search",
|
||||||
|
json={"api_key": api_key, "query": "test", "max_results": 1}
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
data = await response.json()
|
||||||
|
return bool(data.get("results"))
|
||||||
|
return False
|
||||||
|
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/api-keys/{provider}", dependencies=[Depends(require_auth)])
|
||||||
|
async def set_api_key(provider: str, api_key: str = Form(...), db=Depends(get_db)):
|
||||||
|
"""
|
||||||
|
Set API key for a specific provider
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
success = ApiKeyService.set_api_key(db, provider, api_key)
|
||||||
|
if success:
|
||||||
|
return {"message": f"API key for {provider} updated successfully"}
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to update API key")
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/api-keys/{provider}")
|
||||||
|
async def get_api_key(provider: str, db=Depends(get_db)):
|
||||||
|
"""
|
||||||
|
Get API key for a specific provider
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
api_key = ApiKeyService.get_api_key(db, provider)
|
||||||
|
if api_key:
|
||||||
|
return {"provider": provider, "api_key": api_key}
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=404, detail=f"No API key found for {provider}")
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/validate-api-key/{provider}", dependencies=[Depends(require_auth)])
|
||||||
|
async def validate_api_key_endpoint(provider: str, api_key: str = Form(...)):
|
||||||
|
"""
|
||||||
|
Validate an API key by making a test request to the provider
|
||||||
|
This endpoint is used by the frontend to validate API keys without CORS issues
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
is_valid = await validate_api_key(provider, api_key)
|
||||||
|
if is_valid:
|
||||||
|
return {"valid": True, "message": f"Valid {provider} API key"}
|
||||||
|
else:
|
||||||
|
return {"valid": False, "message": f"Invalid {provider} API key"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"valid": False, "message": str(e)}
|
||||||
158
api/debates.py
Normal file
158
api/debates.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
|
||||||
|
from middleware.auth import require_auth
|
||||||
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
from typing import Dict, Any
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
from orchestrator.debate_orchestrator import DebateOrchestrator
|
||||||
|
from models.debate import DebateRequest
|
||||||
|
from storage.session_manager import SessionManager
|
||||||
|
from db_models import get_db
|
||||||
|
|
||||||
|
router = APIRouter(tags=["debates"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/debate/create", dependencies=[Depends(require_auth)])
|
||||||
|
async def create_debate(debate_request: DebateRequest, db=Depends(get_db)) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Create a new debate session with specified parameters
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
orchestrator = DebateOrchestrator(db)
|
||||||
|
session_id = await orchestrator.create_session(debate_request)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"status": "created",
|
||||||
|
"message": f"Debate session {session_id} created successfully"
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/debate/{session_id}")
|
||||||
|
async def get_debate(session_id: str, db=Depends(get_db)) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get the current state of a debate session
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
orchestrator = DebateOrchestrator(db)
|
||||||
|
session = await orchestrator.get_session_status(session_id)
|
||||||
|
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Debate session {session_id} not found")
|
||||||
|
|
||||||
|
return session.dict()
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/debate/{session_id}/start", dependencies=[Depends(require_auth)])
|
||||||
|
async def start_debate(session_id: str, db=Depends(get_db)) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Start a debate session and stream the results
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
orchestrator = DebateOrchestrator(db)
|
||||||
|
session = await orchestrator.run_debate(session_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"status": session.status,
|
||||||
|
"message": f"Debate session {session_id} completed"
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/debate/{session_id}", dependencies=[Depends(require_auth)])
|
||||||
|
async def end_debate(session_id: str, db=Depends(get_db)) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
End a debate session prematurely
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
orchestrator = DebateOrchestrator(db)
|
||||||
|
await orchestrator.terminate_session(session_id)
|
||||||
|
|
||||||
|
return {"session_id": session_id, "status": "terminated"}
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/debate/{session_id}/stream")
|
||||||
|
async def stream_debate(session_id: str, db=Depends(get_db)) -> EventSourceResponse:
|
||||||
|
"""
|
||||||
|
Stream debate updates in real-time using Server-Sent Events
|
||||||
|
"""
|
||||||
|
async def event_generator():
|
||||||
|
session_manager = SessionManager()
|
||||||
|
|
||||||
|
# Check if the session exists
|
||||||
|
session = await session_manager.get_session(db, session_id)
|
||||||
|
if not session:
|
||||||
|
yield {"event": "error", "data": json.dumps({"error": f"Session {session_id} not found"})}
|
||||||
|
return
|
||||||
|
|
||||||
|
# Yield initial state
|
||||||
|
yield {"event": "update", "data": json.dumps({
|
||||||
|
"session_id": session_id,
|
||||||
|
"status": session.status,
|
||||||
|
"rounds": [round.dict() for round in session.rounds],
|
||||||
|
"evidence_library": [e.dict() for e in session.evidence_library],
|
||||||
|
"message": "Debate session initialized"
|
||||||
|
}, default=str)}
|
||||||
|
|
||||||
|
last_round_count = len(session.rounds)
|
||||||
|
# Poll until debate completes or times out (max 5 min)
|
||||||
|
for _ in range(150): # 150 × 2s = 300s timeout
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
# Reset transaction so we see commits from the run_debate request
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
updated_session = await session_manager.get_session(db, session_id)
|
||||||
|
if not updated_session:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Only yield update when rounds actually changed
|
||||||
|
if len(updated_session.rounds) != last_round_count or updated_session.status != session.status:
|
||||||
|
last_round_count = len(updated_session.rounds)
|
||||||
|
session = updated_session
|
||||||
|
yield {"event": "update", "data": json.dumps({
|
||||||
|
"session_id": session_id,
|
||||||
|
"status": updated_session.status,
|
||||||
|
"rounds": [round.dict() for round in updated_session.rounds],
|
||||||
|
"evidence_library": [e.dict() for e in updated_session.evidence_library],
|
||||||
|
"current_round": len(updated_session.rounds)
|
||||||
|
}, default=str)}
|
||||||
|
|
||||||
|
if updated_session.status in ("completed", "terminated"):
|
||||||
|
yield {"event": "complete", "data": json.dumps({
|
||||||
|
"session_id": session_id,
|
||||||
|
"status": updated_session.status,
|
||||||
|
"summary": updated_session.summary,
|
||||||
|
"rounds": [round.dict() for round in updated_session.rounds],
|
||||||
|
"evidence_library": [e.dict() for e in updated_session.evidence_library]
|
||||||
|
}, default=str)}
|
||||||
|
break
|
||||||
|
|
||||||
|
return EventSourceResponse(event_generator())
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions")
|
||||||
|
async def list_sessions(db=Depends(get_db)) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
List all debate sessions
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
session_manager = SessionManager()
|
||||||
|
sessions = await session_manager.list_sessions(db)
|
||||||
|
return {"sessions": sessions}
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
173
api/models.py
Normal file
173
api/models.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from services.api_key_service import ApiKeyService
|
||||||
|
from db_models import get_db
|
||||||
|
from api.api_keys import validate_api_key
|
||||||
|
|
||||||
|
router = APIRouter(tags=["models"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/models/{provider}")
|
||||||
|
async def get_available_models(provider: str, db=Depends(get_db)):
|
||||||
|
"""
|
||||||
|
Get available models for a specific provider by fetching from their API.
|
||||||
|
Falls back to a curated default list if the API key is missing or the request fails.
|
||||||
|
"""
|
||||||
|
# Curated defaults for each provider — used as fallback
|
||||||
|
default_models = {
|
||||||
|
"openai": [
|
||||||
|
{"model_identifier": "gpt-4o", "display_name": "gpt-4o"},
|
||||||
|
{"model_identifier": "gpt-4o-mini", "display_name": "gpt-4o-mini"},
|
||||||
|
{"model_identifier": "gpt-4-turbo", "display_name": "gpt-4-turbo"},
|
||||||
|
{"model_identifier": "gpt-4", "display_name": "gpt-4"},
|
||||||
|
{"model_identifier": "o3-mini", "display_name": "o3-mini"},
|
||||||
|
{"model_identifier": "gpt-3.5-turbo", "display_name": "gpt-3.5-turbo"}
|
||||||
|
],
|
||||||
|
"claude": [
|
||||||
|
{"model_identifier": "claude-opus-4-5", "display_name": "claude-opus-4-5"},
|
||||||
|
{"model_identifier": "claude-sonnet-4-5", "display_name": "claude-sonnet-4-5"},
|
||||||
|
{"model_identifier": "claude-3-5-sonnet-20241022", "display_name": "claude-3-5-sonnet-20241022"},
|
||||||
|
{"model_identifier": "claude-3-5-haiku-20241022", "display_name": "claude-3-5-haiku-20241022"},
|
||||||
|
{"model_identifier": "claude-3-opus-20240229", "display_name": "claude-3-opus-20240229"}
|
||||||
|
],
|
||||||
|
"qwen": [
|
||||||
|
{"model_identifier": "qwen3-max", "display_name": "qwen3-max"},
|
||||||
|
{"model_identifier": "qwen3-plus", "display_name": "qwen3-plus"},
|
||||||
|
{"model_identifier": "qwen3-flash", "display_name": "qwen3-flash"},
|
||||||
|
{"model_identifier": "qwen-max", "display_name": "qwen-max"},
|
||||||
|
{"model_identifier": "qwen-plus", "display_name": "qwen-plus"},
|
||||||
|
{"model_identifier": "qwen-turbo", "display_name": "qwen-turbo"}
|
||||||
|
],
|
||||||
|
"deepseek": [
|
||||||
|
{"model_identifier": "deepseek-chat", "display_name": "deepseek-chat"},
|
||||||
|
{"model_identifier": "deepseek-reasoner", "display_name": "deepseek-reasoner"},
|
||||||
|
{"model_identifier": "deepseek-v3", "display_name": "deepseek-v3"},
|
||||||
|
{"model_identifier": "deepseek-r1", "display_name": "deepseek-r1"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
defaults = default_models.get(provider, [])
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Retrieve and decrypt API key
|
||||||
|
decrypted_key = ApiKeyService.get_api_key(db, provider)
|
||||||
|
if not decrypted_key:
|
||||||
|
return {"provider": provider, "models": defaults}
|
||||||
|
|
||||||
|
# ---------- OpenAI ----------
|
||||||
|
if provider == "openai":
|
||||||
|
import openai
|
||||||
|
async with openai.AsyncOpenAI(api_key=decrypted_key, timeout=10.0) as client:
|
||||||
|
response = await client.models.list()
|
||||||
|
|
||||||
|
# Keep only chat / reasoning models, sorted newest-first by created timestamp
|
||||||
|
chat_prefixes = ('gpt-', 'o1', 'o3', 'o4', 'chatgpt')
|
||||||
|
models = []
|
||||||
|
seen = set()
|
||||||
|
for m in sorted(response.data, key=lambda x: x.created, reverse=True):
|
||||||
|
if m.id not in seen and m.id.startswith(chat_prefixes):
|
||||||
|
seen.add(m.id)
|
||||||
|
models.append({"model_identifier": m.id, "display_name": m.id})
|
||||||
|
|
||||||
|
if models:
|
||||||
|
return {"provider": provider, "models": models}
|
||||||
|
|
||||||
|
# ---------- Claude ----------
|
||||||
|
elif provider == "claude":
|
||||||
|
timeout = aiohttp.ClientTimeout(total=10)
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
headers = {
|
||||||
|
"x-api-key": decrypted_key,
|
||||||
|
"anthropic-version": "2023-06-01"
|
||||||
|
}
|
||||||
|
async with session.get("https://api.anthropic.com/v1/models", headers=headers) as resp:
|
||||||
|
if resp.status == 200:
|
||||||
|
data = await resp.json()
|
||||||
|
models = []
|
||||||
|
for m in data.get("data", []):
|
||||||
|
model_id = m.get("id", "")
|
||||||
|
if model_id.startswith("claude-"):
|
||||||
|
models.append({"model_identifier": model_id, "display_name": model_id})
|
||||||
|
if models:
|
||||||
|
return {"provider": provider, "models": models}
|
||||||
|
else:
|
||||||
|
print(f"Claude models API returned {resp.status}: {await resp.text()}")
|
||||||
|
|
||||||
|
# ---------- Qwen ----------
|
||||||
|
elif provider == "qwen":
|
||||||
|
timeout = aiohttp.ClientTimeout(total=10)
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
headers = {"Authorization": f"Bearer {decrypted_key}"}
|
||||||
|
async with session.get("https://dashscope.aliyuncs.com/compatible-mode/v1/models", headers=headers) as resp:
|
||||||
|
if resp.status == 200:
|
||||||
|
data = await resp.json()
|
||||||
|
exclude_keywords = [
|
||||||
|
'tts', 'vl', 'ocr', 'image', 'asr', '-mt-', '-mt',
|
||||||
|
'math', 'embed', 'rerank', 'coder', 'translate',
|
||||||
|
's2s', 'deep-search', 'omni', 'gui-'
|
||||||
|
]
|
||||||
|
models = []
|
||||||
|
for m in data.get("data", []):
|
||||||
|
model_id = m.get("id", "")
|
||||||
|
if not model_id:
|
||||||
|
continue
|
||||||
|
if not (model_id.startswith("qwen") or model_id.startswith("qwq")):
|
||||||
|
continue
|
||||||
|
if any(kw in model_id for kw in exclude_keywords):
|
||||||
|
continue
|
||||||
|
models.append({"model_identifier": model_id, "display_name": model_id})
|
||||||
|
if models:
|
||||||
|
return {"provider": provider, "models": models}
|
||||||
|
else:
|
||||||
|
print(f"Qwen models API returned {resp.status}: {await resp.text()}")
|
||||||
|
|
||||||
|
# ---------- DeepSeek ----------
|
||||||
|
elif provider == "deepseek":
|
||||||
|
timeout = aiohttp.ClientTimeout(total=10)
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
headers = {"Authorization": f"Bearer {decrypted_key}"}
|
||||||
|
async with session.get("https://api.deepseek.com/v1/models", headers=headers) as resp:
|
||||||
|
if resp.status == 200:
|
||||||
|
data = await resp.json()
|
||||||
|
models = []
|
||||||
|
for m in data.get("data", []):
|
||||||
|
model_id = m.get("id", "")
|
||||||
|
if model_id.startswith("deepseek"):
|
||||||
|
models.append({"model_identifier": model_id, "display_name": model_id})
|
||||||
|
if models:
|
||||||
|
return {"provider": provider, "models": models}
|
||||||
|
else:
|
||||||
|
print(f"DeepSeek models API returned {resp.status}: {await resp.text()}")
|
||||||
|
|
||||||
|
# API fetch succeeded but returned empty list, or unknown provider — use defaults
|
||||||
|
return {"provider": provider, "models": defaults}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error fetching models for {provider}: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return {"provider": provider, "models": defaults}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/providers")
|
||||||
|
async def get_available_providers(db=Depends(get_db)):
|
||||||
|
"""
|
||||||
|
Get all providers that have valid API keys set
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
available_providers = []
|
||||||
|
for provider in ("openai", "claude", "qwen", "deepseek", "tavily"):
|
||||||
|
decrypted_key = ApiKeyService.get_api_key(db, provider)
|
||||||
|
if not decrypted_key:
|
||||||
|
continue
|
||||||
|
is_valid = await validate_api_key(provider, decrypted_key)
|
||||||
|
if is_valid:
|
||||||
|
available_providers.append({
|
||||||
|
"provider": provider,
|
||||||
|
"has_valid_key": True
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"providers": available_providers}
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
192
api/setup.py
Normal file
192
api/setup.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from services.config_service import ConfigService
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/setup", tags=["setup"])
|
||||||
|
|
||||||
|
PASSWORD_PLACEHOLDER = "********"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Request / response models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class DatabaseConfig(BaseModel):
|
||||||
|
host: str
|
||||||
|
port: int = 3306
|
||||||
|
user: str = "root"
|
||||||
|
password: str = ""
|
||||||
|
database: str = "dialectica"
|
||||||
|
|
||||||
|
|
||||||
|
class KeycloakConfig(BaseModel):
|
||||||
|
host: str = ""
|
||||||
|
realm: str = ""
|
||||||
|
client_id: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class TlsConfig(BaseModel):
|
||||||
|
cert_path: str = ""
|
||||||
|
key_path: str = ""
|
||||||
|
force_https: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class FullConfig(BaseModel):
|
||||||
|
database: Optional[DatabaseConfig] = None
|
||||||
|
keycloak: Optional[KeycloakConfig] = None
|
||||||
|
tls: Optional[TlsConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Access control dependency
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
async def setup_guard(request: Request):
|
||||||
|
"""Three-phase access control for setup routes.
|
||||||
|
|
||||||
|
1. Not initialized → only localhost allowed
|
||||||
|
2. ENV_MODE=dev → open
|
||||||
|
3. ENV_MODE=prod → Keycloak admin JWT required
|
||||||
|
"""
|
||||||
|
config = ConfigService.load()
|
||||||
|
initialized = config.get("initialized", False)
|
||||||
|
env_mode = os.getenv("ENV_MODE", "dev")
|
||||||
|
|
||||||
|
if env_mode == "dev":
|
||||||
|
return # dev mode: no auth needed, even before initialisation
|
||||||
|
|
||||||
|
if not initialized:
|
||||||
|
# prod + not initialised: only localhost may configure
|
||||||
|
client_ip = request.client.host
|
||||||
|
if client_ip not in ("127.0.0.1", "::1"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="初次设置仅允许从本机访问",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# prod → delegate to Keycloak middleware (Phase 3)
|
||||||
|
from app.middleware.auth import require_admin
|
||||||
|
await require_admin(request, config)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Routes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
@router.get("/status")
|
||||||
|
async def setup_status():
|
||||||
|
"""Return current system initialisation state, including KC info for OIDC."""
|
||||||
|
config = ConfigService.load()
|
||||||
|
env_mode = os.getenv("ENV_MODE", "dev")
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"initialized": config.get("initialized", False),
|
||||||
|
"env_mode": env_mode,
|
||||||
|
"db_configured": ConfigService.is_db_configured(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Include Keycloak info so the frontend can build OIDC config
|
||||||
|
kc = config.get("keycloak", {})
|
||||||
|
if env_mode == "prod" and kc.get("host") and kc.get("realm"):
|
||||||
|
result["keycloak"] = {
|
||||||
|
"authority": f"{kc['host']}/realms/{kc['realm']}",
|
||||||
|
"client_id": kc.get("client_id", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/config", dependencies=[Depends(setup_guard)])
|
||||||
|
async def get_config():
|
||||||
|
"""Return full config with passwords replaced by placeholder."""
|
||||||
|
config = ConfigService.load()
|
||||||
|
if "database" in config and config["database"].get("password"):
|
||||||
|
config["database"]["password"] = PASSWORD_PLACEHOLDER
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/config", dependencies=[Depends(setup_guard)])
|
||||||
|
async def update_config(payload: FullConfig):
|
||||||
|
"""Merge submitted config sections into the YAML file."""
|
||||||
|
config = ConfigService.load()
|
||||||
|
|
||||||
|
if payload.database is not None:
|
||||||
|
db_dict = payload.database.model_dump()
|
||||||
|
# If password is the placeholder, keep the existing real password
|
||||||
|
if db_dict.get("password") == PASSWORD_PLACEHOLDER:
|
||||||
|
db_dict["password"] = config.get("database", {}).get("password", "")
|
||||||
|
config["database"] = db_dict
|
||||||
|
if payload.keycloak is not None:
|
||||||
|
config["keycloak"] = payload.keycloak.model_dump()
|
||||||
|
if payload.tls is not None:
|
||||||
|
config["tls"] = payload.tls.model_dump()
|
||||||
|
|
||||||
|
ConfigService.save(config)
|
||||||
|
return {"message": "配置已保存"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/test-db", dependencies=[Depends(setup_guard)])
|
||||||
|
async def test_db_connection(db_config: DatabaseConfig):
|
||||||
|
"""Test a database connection with the provided parameters (no save)."""
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
|
||||||
|
password = db_config.password
|
||||||
|
# If password is the placeholder, use the real password from config
|
||||||
|
if password == PASSWORD_PLACEHOLDER:
|
||||||
|
password = ConfigService.load().get("database", {}).get("password", "")
|
||||||
|
|
||||||
|
url = (
|
||||||
|
f"mysql+pymysql://{quote_plus(db_config.user)}:{quote_plus(password)}"
|
||||||
|
f"@{db_config.host}:{db_config.port}/{db_config.database}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
engine = create_engine(url, pool_pre_ping=True)
|
||||||
|
with engine.connect() as conn:
|
||||||
|
conn.execute(text("SELECT 1"))
|
||||||
|
engine.dispose()
|
||||||
|
return {"success": True, "message": "数据库连接成功"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/test-keycloak", dependencies=[Depends(setup_guard)])
|
||||||
|
async def test_keycloak(kc_config: KeycloakConfig):
|
||||||
|
"""Test Keycloak connectivity by fetching the OIDC discovery document."""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
well_known = (
|
||||||
|
f"{kc_config.host}/realms/{kc_config.realm}"
|
||||||
|
f"/.well-known/openid-configuration"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10) as client:
|
||||||
|
resp = await client.get(well_known)
|
||||||
|
if resp.status_code == 200:
|
||||||
|
return {"success": True, "message": "Keycloak 连通正常"}
|
||||||
|
return {"success": False, "message": f"HTTP {resp.status_code}"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/initialize", dependencies=[Depends(setup_guard)])
|
||||||
|
async def initialize():
|
||||||
|
"""Mark system as initialised and reload DB connection."""
|
||||||
|
config = ConfigService.load()
|
||||||
|
|
||||||
|
if not ConfigService.is_db_configured():
|
||||||
|
raise HTTPException(status_code=400, detail="请先配置数据库连接")
|
||||||
|
|
||||||
|
# Reload DB engine so business routes can start working
|
||||||
|
from app.db_models import reload_db_connection
|
||||||
|
from app.storage.database import init_db
|
||||||
|
|
||||||
|
reload_db_connection()
|
||||||
|
init_db()
|
||||||
|
|
||||||
|
config["initialized"] = True
|
||||||
|
ConfigService.save(config)
|
||||||
|
|
||||||
|
return {"message": "系统初始化完成", "initialized": True}
|
||||||
59
app.py
Normal file
59
app.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
from storage.database import init_db
|
||||||
|
from exceptions import ServiceNotConfiguredError
|
||||||
|
from middleware.config_guard import ConfigGuardMiddleware
|
||||||
|
from api import debates_router, api_keys_router, models_router, setup_router
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Dialectica - Multi-Model Debate Framework",
|
||||||
|
description="A framework for structured debates between multiple language models",
|
||||||
|
version="0.1.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add CORS middleware
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # Allow all origins for development
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"], # Allow all methods
|
||||||
|
allow_headers=["*"], # Allow all headers
|
||||||
|
# Add exposed headers to allow frontend to access response headers
|
||||||
|
expose_headers=["Access-Control-Allow-Origin", "Access-Control-Allow-Credentials"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Config guard: return 503 for business routes when DB not configured
|
||||||
|
app.add_middleware(ConfigGuardMiddleware)
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(ServiceNotConfiguredError)
|
||||||
|
async def not_configured_handler(request, exc):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=503,
|
||||||
|
content={"error_code": "SERVICE_NOT_CONFIGURED", "detail": str(exc)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
def startup_event():
|
||||||
|
"""Initialize database on startup (skipped if not configured)."""
|
||||||
|
init_db()
|
||||||
|
|
||||||
|
|
||||||
|
# Register routers
|
||||||
|
app.include_router(debates_router)
|
||||||
|
app.include_router(api_keys_router)
|
||||||
|
app.include_router(models_router)
|
||||||
|
app.include_router(setup_router)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
return {"message": "Welcome to Dialectica - Multi-Model Debate Framework"}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8020)
|
||||||
79
db_models/__init__.py
Normal file
79
db_models/__init__.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
from sqlalchemy import Column, Integer, String, Boolean, DateTime, Text
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from exceptions import ServiceNotConfiguredError
|
||||||
|
from services.config_service import ConfigService
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
class ApiKey(Base):
|
||||||
|
__tablename__ = "api_keys"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
provider = Column(String(50), unique=True, index=True, nullable=False)
|
||||||
|
api_key_encrypted = Column(Text, nullable=False) # Encrypted API key
|
||||||
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
|
updated_at = Column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelConfig(Base):
|
||||||
|
__tablename__ = "model_configs"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
provider = Column(String(50), nullable=False)
|
||||||
|
model_name = Column(String(100), nullable=False)
|
||||||
|
display_name = Column(String(100)) # Optional display name
|
||||||
|
is_active = Column(Boolean, default=True)
|
||||||
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
|
updated_at = Column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
|
__table_args__ = {'mysql_charset': 'utf8mb4'}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Lazy engine / session factory — only created when first requested
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
_engine = None
|
||||||
|
_SessionLocal = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_engine():
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
global _engine
|
||||||
|
if _engine is None:
|
||||||
|
db_url = ConfigService.get_database_url()
|
||||||
|
if not db_url:
|
||||||
|
raise ServiceNotConfiguredError("数据库未配置")
|
||||||
|
_engine = create_engine(db_url)
|
||||||
|
return _engine
|
||||||
|
|
||||||
|
|
||||||
|
def _get_session_factory():
|
||||||
|
global _SessionLocal
|
||||||
|
if _SessionLocal is None:
|
||||||
|
_SessionLocal = sessionmaker(
|
||||||
|
autocommit=False, autoflush=False, bind=_get_engine()
|
||||||
|
)
|
||||||
|
return _SessionLocal
|
||||||
|
|
||||||
|
|
||||||
|
def reload_db_connection():
|
||||||
|
"""Dispose current engine and reset so next call rebuilds from config."""
|
||||||
|
global _engine, _SessionLocal
|
||||||
|
if _engine is not None:
|
||||||
|
_engine.dispose()
|
||||||
|
_engine = None
|
||||||
|
_SessionLocal = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_db():
|
||||||
|
"""FastAPI dependency that yields a DB session."""
|
||||||
|
session_factory = _get_session_factory()
|
||||||
|
db = session_factory()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
3
exceptions/__init__.py
Normal file
3
exceptions/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
class ServiceNotConfiguredError(Exception):
|
||||||
|
"""Raised when a required service (database, etc.) is not configured."""
|
||||||
|
pass
|
||||||
0
middleware/__init__.py
Normal file
0
middleware/__init__.py
Normal file
114
middleware/auth.py
Normal file
114
middleware/auth.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
"""Keycloak JWT authentication middleware."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
from jose import jwt, JWTError, jwk
|
||||||
|
from jose.utils import base64url_decode
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
# Cache JWKS per (host, realm) to avoid fetching on every request
|
||||||
|
_jwks_cache: dict[str, dict] = {}
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_jwks(kc_host: str, realm: str) -> dict:
|
||||||
|
cache_key = f"{kc_host}/{realm}"
|
||||||
|
if cache_key in _jwks_cache:
|
||||||
|
return _jwks_cache[cache_key]
|
||||||
|
|
||||||
|
url = f"{kc_host}/realms/{realm}/protocol/openid-connect/certs"
|
||||||
|
async with httpx.AsyncClient(timeout=10) as client:
|
||||||
|
resp = await client.get(url)
|
||||||
|
if resp.status_code != 200:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail=f"无法获取 Keycloak JWKS: HTTP {resp.status_code}",
|
||||||
|
)
|
||||||
|
data = resp.json()
|
||||||
|
_jwks_cache[cache_key] = data
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _find_rsa_key(jwks: dict, token: str) -> dict | None:
|
||||||
|
"""Find the matching RSA key from JWKS for the token's kid."""
|
||||||
|
unverified_header = jwt.get_unverified_header(token)
|
||||||
|
kid = unverified_header.get("kid")
|
||||||
|
for key in jwks.get("keys", []):
|
||||||
|
if key.get("kid") == kid:
|
||||||
|
return key
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_token(request: Request, kc_host: str, realm: str) -> dict:
|
||||||
|
"""Extract and verify the Bearer JWT from the Authorization header.
|
||||||
|
|
||||||
|
Returns the decoded payload on success.
|
||||||
|
Raises HTTPException(401) on missing/invalid token.
|
||||||
|
"""
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header.startswith("Bearer "):
|
||||||
|
raise HTTPException(status_code=401, detail="缺少 Authorization Bearer token")
|
||||||
|
|
||||||
|
token = auth_header[7:]
|
||||||
|
|
||||||
|
jwks = await _get_jwks(kc_host, realm)
|
||||||
|
rsa_key = _find_rsa_key(jwks, token)
|
||||||
|
if rsa_key is None:
|
||||||
|
# Clear cache in case keys rotated
|
||||||
|
_jwks_cache.pop(f"{kc_host}/{realm}", None)
|
||||||
|
jwks = await _get_jwks(kc_host, realm)
|
||||||
|
rsa_key = _find_rsa_key(jwks, token)
|
||||||
|
if rsa_key is None:
|
||||||
|
raise HTTPException(status_code=401, detail="无法匹配 JWT 签名密钥")
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
rsa_key,
|
||||||
|
algorithms=["RS256"],
|
||||||
|
options={"verify_aud": False}, # Keycloak audience varies by client
|
||||||
|
)
|
||||||
|
return payload
|
||||||
|
except JWTError as e:
|
||||||
|
raise HTTPException(status_code=401, detail=f"JWT 验证失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def require_auth(request: Request):
|
||||||
|
"""Verify Bearer JWT for write endpoints.
|
||||||
|
|
||||||
|
Dev mode: passthrough (no auth required).
|
||||||
|
Prod mode: validates JWT via Keycloak JWKS.
|
||||||
|
"""
|
||||||
|
if os.getenv("ENV_MODE", "dev") == "dev":
|
||||||
|
return None
|
||||||
|
|
||||||
|
from app.services.config_service import ConfigService
|
||||||
|
config = ConfigService.load()
|
||||||
|
kc = config.get("keycloak", {})
|
||||||
|
if not kc.get("host"):
|
||||||
|
return None # KC not configured – allow access
|
||||||
|
|
||||||
|
return await verify_token(request, kc["host"], kc.get("realm", ""))
|
||||||
|
|
||||||
|
|
||||||
|
async def require_admin(request: Request, config: dict):
|
||||||
|
"""Verify the request carries a valid Keycloak JWT with admin role.
|
||||||
|
|
||||||
|
Raises HTTPException(401/403) on failure.
|
||||||
|
"""
|
||||||
|
kc = config.get("keycloak", {})
|
||||||
|
kc_host = kc.get("host")
|
||||||
|
realm = kc.get("realm")
|
||||||
|
|
||||||
|
if not kc_host or not realm:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503,
|
||||||
|
detail="Keycloak 未配置,无法进行鉴权",
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = await verify_token(request, kc_host, realm)
|
||||||
|
|
||||||
|
roles = payload.get("realm_access", {}).get("roles", [])
|
||||||
|
if "admin" not in roles:
|
||||||
|
raise HTTPException(status_code=403, detail="需要 admin 角色")
|
||||||
|
|
||||||
|
return payload
|
||||||
32
middleware/config_guard.py
Normal file
32
middleware/config_guard.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
|
from services.config_service import ConfigService
|
||||||
|
|
||||||
|
# Paths that are always accessible, even when DB is not configured
|
||||||
|
_ALLOWED_PREFIXES = ("/api/setup", "/docs", "/openapi", "/redoc")
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigGuardMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Return 503 for all business routes when the database is not configured."""
|
||||||
|
|
||||||
|
async def dispatch(self, request, call_next):
|
||||||
|
path = request.url.path
|
||||||
|
|
||||||
|
# Always allow: setup routes, root, docs, OPTIONS (CORS preflight)
|
||||||
|
if path == "/" or request.method == "OPTIONS":
|
||||||
|
return await call_next(request)
|
||||||
|
for prefix in _ALLOWED_PREFIXES:
|
||||||
|
if path.startswith(prefix):
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
if not ConfigService.is_db_configured():
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=503,
|
||||||
|
content={
|
||||||
|
"error_code": "SERVICE_NOT_CONFIGURED",
|
||||||
|
"detail": "数据库未配置,请先完成系统初始化",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return await call_next(request)
|
||||||
0
models/__init__.py
Normal file
0
models/__init__.py
Normal file
92
models/debate.py
Normal file
92
models/debate.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List, Optional
|
||||||
|
from enum import Enum
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class ModelProvider(str, Enum):
|
||||||
|
OPENAI = "openai"
|
||||||
|
DEEPSEEK = "deepseek"
|
||||||
|
QWEN = "qwen"
|
||||||
|
CLAUDE = "claude"
|
||||||
|
|
||||||
|
|
||||||
|
class DebateStance(str, Enum):
|
||||||
|
PRO = "pro"
|
||||||
|
CON = "con"
|
||||||
|
|
||||||
|
|
||||||
|
class SearchResult(BaseModel):
|
||||||
|
title: str
|
||||||
|
url: str
|
||||||
|
snippet: str
|
||||||
|
score: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SearchEvidence(BaseModel):
|
||||||
|
query: str
|
||||||
|
results: List[SearchResult]
|
||||||
|
mode: str # "auto", "tool", "both"
|
||||||
|
|
||||||
|
|
||||||
|
class DebateRound(BaseModel):
|
||||||
|
round_number: int
|
||||||
|
speaker: str # Model identifier
|
||||||
|
stance: DebateStance
|
||||||
|
content: str
|
||||||
|
timestamp: datetime
|
||||||
|
token_count: Optional[int] = None
|
||||||
|
search_evidence: Optional[SearchEvidence] = None
|
||||||
|
|
||||||
|
|
||||||
|
class DebateParticipant(BaseModel):
|
||||||
|
model_config = {"protected_namespaces": ()}
|
||||||
|
|
||||||
|
model_identifier: str
|
||||||
|
provider: ModelProvider
|
||||||
|
stance: DebateStance
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class DebateConstraints(BaseModel):
|
||||||
|
max_rounds: int = 5
|
||||||
|
max_tokens_per_turn: int = 500
|
||||||
|
max_total_tokens: Optional[int] = None
|
||||||
|
forbid_repetition: bool = True
|
||||||
|
must_respond_to_opponent: bool = True
|
||||||
|
web_search_enabled: bool = False
|
||||||
|
web_search_mode: str = "auto" # "auto", "tool", "both"
|
||||||
|
|
||||||
|
|
||||||
|
class DebateRequest(BaseModel):
|
||||||
|
topic: str
|
||||||
|
participants: List[DebateParticipant]
|
||||||
|
constraints: DebateConstraints
|
||||||
|
custom_system_prompt: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class EvidenceReference(BaseModel):
|
||||||
|
round_number: int
|
||||||
|
speaker: str
|
||||||
|
stance: DebateStance
|
||||||
|
|
||||||
|
|
||||||
|
class EvidenceEntry(BaseModel):
|
||||||
|
title: str
|
||||||
|
url: str
|
||||||
|
snippet: str
|
||||||
|
score: Optional[float] = None
|
||||||
|
references: List[EvidenceReference]
|
||||||
|
|
||||||
|
|
||||||
|
class DebateSession(BaseModel):
|
||||||
|
session_id: str
|
||||||
|
topic: str
|
||||||
|
participants: List[DebateParticipant]
|
||||||
|
constraints: DebateConstraints
|
||||||
|
rounds: List[DebateRound]
|
||||||
|
status: str # "active", "completed", "terminated"
|
||||||
|
created_at: datetime
|
||||||
|
completed_at: Optional[datetime] = None
|
||||||
|
summary: Optional[str] = None
|
||||||
|
evidence_library: List[EvidenceEntry] = []
|
||||||
0
orchestrator/__init__.py
Normal file
0
orchestrator/__init__.py
Normal file
358
orchestrator/debate_orchestrator.py
Normal file
358
orchestrator/debate_orchestrator.py
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from models.debate import (
|
||||||
|
DebateRequest, DebateSession, DebateRound,
|
||||||
|
DebateStance, DebateParticipant, SearchEvidence, SearchResult,
|
||||||
|
EvidenceEntry, EvidenceReference
|
||||||
|
)
|
||||||
|
from providers.provider_factory import ProviderFactory
|
||||||
|
from storage.session_manager import SessionManager
|
||||||
|
from utils.summarizer import summarize_debate
|
||||||
|
from services.search_service import SearchService
|
||||||
|
from services.api_key_service import ApiKeyService
|
||||||
|
|
||||||
|
|
||||||
|
class DebateOrchestrator:
|
||||||
|
"""
|
||||||
|
Orchestrates the debate between multiple language models
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
self.session_manager = SessionManager()
|
||||||
|
self.provider_factory = ProviderFactory()
|
||||||
|
|
||||||
|
async def create_session(self, debate_request: DebateRequest) -> str:
|
||||||
|
"""
|
||||||
|
Create a new debate session
|
||||||
|
"""
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
session = DebateSession(
|
||||||
|
session_id=session_id,
|
||||||
|
topic=debate_request.topic,
|
||||||
|
participants=debate_request.participants,
|
||||||
|
constraints=debate_request.constraints,
|
||||||
|
rounds=[],
|
||||||
|
status="active",
|
||||||
|
created_at=datetime.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.session_manager.save_session(self.db, session)
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
def _get_search_service(self) -> Optional[SearchService]:
|
||||||
|
"""
|
||||||
|
Get a SearchService instance if Tavily API key is available.
|
||||||
|
"""
|
||||||
|
tavily_key = ApiKeyService.get_api_key(self.db, "tavily")
|
||||||
|
if tavily_key:
|
||||||
|
return SearchService(api_key=tavily_key)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def run_debate(self, session_id: str) -> DebateSession:
|
||||||
|
"""
|
||||||
|
Run the complete debate process
|
||||||
|
"""
|
||||||
|
session = await self.session_manager.get_session(self.db, session_id)
|
||||||
|
if not session:
|
||||||
|
raise ValueError(f"Session {session_id} not found")
|
||||||
|
|
||||||
|
# Initialize providers for each participant
|
||||||
|
providers = {}
|
||||||
|
for participant in session.participants:
|
||||||
|
provider = self.provider_factory.create_provider(
|
||||||
|
self.db,
|
||||||
|
participant.provider,
|
||||||
|
participant.api_key # This can be None, and the provider will fetch from DB
|
||||||
|
)
|
||||||
|
providers[participant.model_identifier] = provider
|
||||||
|
|
||||||
|
# Initialize search service if web search is enabled
|
||||||
|
search_service = None
|
||||||
|
web_search_enabled = session.constraints.web_search_enabled
|
||||||
|
web_search_mode = session.constraints.web_search_mode
|
||||||
|
if web_search_enabled:
|
||||||
|
search_service = self._get_search_service()
|
||||||
|
if not search_service:
|
||||||
|
print("Warning: Web search enabled but no Tavily API key found. Disabling search.")
|
||||||
|
web_search_enabled = False
|
||||||
|
|
||||||
|
# Run the debate rounds
|
||||||
|
for round_num in range(session.constraints.max_rounds):
|
||||||
|
if session.status != "active":
|
||||||
|
break
|
||||||
|
|
||||||
|
# Alternate between participants
|
||||||
|
current_participant = session.participants[round_num % len(session.participants)]
|
||||||
|
provider = providers[current_participant.model_identifier]
|
||||||
|
|
||||||
|
# Perform automatic search if enabled
|
||||||
|
search_evidence = None
|
||||||
|
if web_search_enabled and web_search_mode in ("auto", "both"):
|
||||||
|
search_evidence = self._perform_automatic_search(
|
||||||
|
search_service, session, round_num
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare context for the current turn (with search results if available)
|
||||||
|
context = self._prepare_context(session, current_participant.stance, search_evidence)
|
||||||
|
|
||||||
|
# Determine if we should use tool calling for this round
|
||||||
|
use_tool_calling = (
|
||||||
|
web_search_enabled
|
||||||
|
and web_search_mode in ("tool", "both")
|
||||||
|
and provider.supports_tools()
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_tool_calling:
|
||||||
|
response, tool_evidence = await self._handle_tool_calls(
|
||||||
|
provider, current_participant.model_identifier,
|
||||||
|
context, search_service
|
||||||
|
)
|
||||||
|
# Merge tool-based evidence with auto evidence
|
||||||
|
if tool_evidence:
|
||||||
|
if search_evidence:
|
||||||
|
search_evidence.results.extend(tool_evidence.results)
|
||||||
|
search_evidence.query += f" | {tool_evidence.query}"
|
||||||
|
search_evidence.mode = "both"
|
||||||
|
else:
|
||||||
|
search_evidence = tool_evidence
|
||||||
|
else:
|
||||||
|
response = await provider.generate_response(
|
||||||
|
model=current_participant.model_identifier,
|
||||||
|
prompt=context
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean the response to remove any echoed prompt/meta text
|
||||||
|
response = self._clean_response(response)
|
||||||
|
|
||||||
|
# Create a new round
|
||||||
|
round_data = DebateRound(
|
||||||
|
round_number=round_num + 1,
|
||||||
|
speaker=current_participant.model_identifier,
|
||||||
|
stance=current_participant.stance,
|
||||||
|
content=response,
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
token_count=len(response.split()), # Approximate token count
|
||||||
|
search_evidence=search_evidence
|
||||||
|
)
|
||||||
|
|
||||||
|
session.rounds.append(round_data)
|
||||||
|
|
||||||
|
# Update evidence library with search results from this round
|
||||||
|
if round_data.search_evidence:
|
||||||
|
self._update_evidence_library(session, round_data)
|
||||||
|
|
||||||
|
# Update session in storage
|
||||||
|
await self.session_manager.update_session(self.db, session)
|
||||||
|
|
||||||
|
# Small delay between rounds to simulate realistic interaction
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
# Generate summary after all rounds are complete
|
||||||
|
summary = await summarize_debate(session)
|
||||||
|
session.summary = summary
|
||||||
|
session.status = "completed"
|
||||||
|
session.completed_at = datetime.now()
|
||||||
|
|
||||||
|
await self.session_manager.update_session(self.db, session)
|
||||||
|
return session
|
||||||
|
|
||||||
|
def _perform_automatic_search(
|
||||||
|
self, search_service: SearchService, session: DebateSession, round_num: int
|
||||||
|
) -> Optional[SearchEvidence]:
|
||||||
|
"""
|
||||||
|
Perform an automatic web search based on the topic and last opponent argument.
|
||||||
|
"""
|
||||||
|
last_opponent_arg = None
|
||||||
|
if session.rounds:
|
||||||
|
last_opponent_arg = session.rounds[-1].content
|
||||||
|
|
||||||
|
query = SearchService.generate_search_query(session.topic, last_opponent_arg)
|
||||||
|
results = search_service.search(query, max_results=3)
|
||||||
|
|
||||||
|
if results:
|
||||||
|
return SearchEvidence(
|
||||||
|
query=query,
|
||||||
|
results=results,
|
||||||
|
mode="auto"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _handle_tool_calls(
|
||||||
|
self, provider, model: str, context: str, search_service: SearchService
|
||||||
|
) -> tuple:
|
||||||
|
"""
|
||||||
|
Handle tool calling flow: send prompt with tools, execute any tool calls,
|
||||||
|
then re-prompt with results. Max 2 tool call iterations.
|
||||||
|
Returns (final_response_text, SearchEvidence_or_None).
|
||||||
|
"""
|
||||||
|
tools = [SearchService.get_tool_definition()]
|
||||||
|
all_search_results = []
|
||||||
|
all_queries = []
|
||||||
|
|
||||||
|
text, tool_calls = await provider.generate_response_with_tools(
|
||||||
|
model=model,
|
||||||
|
prompt=context,
|
||||||
|
tools=tools,
|
||||||
|
max_tokens=500
|
||||||
|
)
|
||||||
|
|
||||||
|
# If no tool calls, return the text response directly
|
||||||
|
if not tool_calls:
|
||||||
|
return text, None
|
||||||
|
|
||||||
|
# Process up to 2 rounds of tool calls
|
||||||
|
for iteration in range(2):
|
||||||
|
if not tool_calls:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Execute each tool call
|
||||||
|
tool_results_text = []
|
||||||
|
for tc in tool_calls:
|
||||||
|
if tc["name"] == "web_search":
|
||||||
|
query = tc["arguments"].get("query", "")
|
||||||
|
all_queries.append(query)
|
||||||
|
results = search_service.search(query, max_results=3)
|
||||||
|
all_search_results.extend(results)
|
||||||
|
|
||||||
|
# Format results for the model
|
||||||
|
evidence = SearchEvidence(query=query, results=results, mode="tool")
|
||||||
|
tool_results_text.append(SearchService.format_results_for_context(evidence))
|
||||||
|
|
||||||
|
# Re-prompt the model with tool results
|
||||||
|
augmented_context = context + "\n" + "\n".join(tool_results_text)
|
||||||
|
augmented_context += "\n请基于以上搜索结果和辩论历史,给出你的论点。"
|
||||||
|
|
||||||
|
text, tool_calls = await provider.generate_response_with_tools(
|
||||||
|
model=model,
|
||||||
|
prompt=augmented_context,
|
||||||
|
tools=tools,
|
||||||
|
max_tokens=500
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build combined evidence
|
||||||
|
evidence = None
|
||||||
|
if all_search_results:
|
||||||
|
evidence = SearchEvidence(
|
||||||
|
query=" | ".join(all_queries),
|
||||||
|
results=all_search_results,
|
||||||
|
mode="tool"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If we still got no text (model keeps calling tools), fall back
|
||||||
|
if not text:
|
||||||
|
text = await provider.generate_response(model=model, prompt=context, max_tokens=500)
|
||||||
|
|
||||||
|
return text, evidence
|
||||||
|
|
||||||
|
def _update_evidence_library(self, session: DebateSession, round_data: DebateRound):
|
||||||
|
"""
|
||||||
|
Merge search results from a round into the session's evidence library, deduplicating by URL.
|
||||||
|
"""
|
||||||
|
ref = EvidenceReference(
|
||||||
|
round_number=round_data.round_number,
|
||||||
|
speaker=round_data.speaker,
|
||||||
|
stance=round_data.stance
|
||||||
|
)
|
||||||
|
|
||||||
|
url_index = {entry.url: i for i, entry in enumerate(session.evidence_library)}
|
||||||
|
|
||||||
|
for result in round_data.search_evidence.results:
|
||||||
|
if result.url in url_index:
|
||||||
|
entry = session.evidence_library[url_index[result.url]]
|
||||||
|
# Avoid duplicate references (same round + speaker)
|
||||||
|
if not any(
|
||||||
|
r.round_number == ref.round_number and r.speaker == ref.speaker
|
||||||
|
for r in entry.references
|
||||||
|
):
|
||||||
|
entry.references.append(ref)
|
||||||
|
else:
|
||||||
|
new_entry = EvidenceEntry(
|
||||||
|
title=result.title,
|
||||||
|
url=result.url,
|
||||||
|
snippet=result.snippet,
|
||||||
|
score=result.score,
|
||||||
|
references=[ref]
|
||||||
|
)
|
||||||
|
session.evidence_library.append(new_entry)
|
||||||
|
url_index[result.url] = len(session.evidence_library) - 1
|
||||||
|
|
||||||
|
def _prepare_context(
|
||||||
|
self, session: DebateSession, current_stance: DebateStance,
|
||||||
|
search_evidence: Optional[SearchEvidence] = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Prepare the context/prompt for the current model turn
|
||||||
|
"""
|
||||||
|
# Determine the stance of the current speaker
|
||||||
|
if current_stance == DebateStance.PRO:
|
||||||
|
position_desc = "正方(支持方)"
|
||||||
|
opposing_desc = "反方(反对方)"
|
||||||
|
else:
|
||||||
|
position_desc = "反方(反对方)"
|
||||||
|
opposing_desc = "正方(支持方)"
|
||||||
|
|
||||||
|
# Build the context with previous rounds
|
||||||
|
context_parts = [
|
||||||
|
f"辩论主题: {session.topic}",
|
||||||
|
f"你的立场: {position_desc}",
|
||||||
|
"辩论规则:",
|
||||||
|
"- 必须回应对方上一轮的核心论点",
|
||||||
|
"- 不得重复自己已提出的观点",
|
||||||
|
"- 输出长度限制在合理范围内",
|
||||||
|
"\n历史辩论记录:"
|
||||||
|
]
|
||||||
|
|
||||||
|
for round_data in session.rounds:
|
||||||
|
stance_text = "正方" if round_data.stance == DebateStance.PRO else "反方"
|
||||||
|
context_parts.append(f"第{round_data.round_number}轮 - {stance_text}: {round_data.content}")
|
||||||
|
|
||||||
|
# Inject search results if available
|
||||||
|
if search_evidence:
|
||||||
|
context_parts.append(SearchService.format_results_for_context(search_evidence))
|
||||||
|
|
||||||
|
context_parts.append(f"\n现在轮到你 ({position_desc}) 发言,请基于以上内容进行回应。注意:直接给出你的论点内容,不要重复上述提示词、辩论规则或历史记录。")
|
||||||
|
|
||||||
|
return "\n".join(context_parts)
|
||||||
|
|
||||||
|
def _clean_response(self, response: str) -> str:
|
||||||
|
"""
|
||||||
|
Clean the model response to remove any echoed prompt/meta text
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Remove common prompt echoes and meta prefixes
|
||||||
|
patterns_to_remove = [
|
||||||
|
r'^第\d+轮\s*[-::]\s*(正方|反方)\s*[::]?\s*', # 第X轮 - 正方/反方:
|
||||||
|
r'^(正方|反方)\s*[((][^))]*[))]\s*[::]?\s*', # 正方(支持方):
|
||||||
|
r'^(正方|反方)\s*[::]\s*', # 正方: or 反方:
|
||||||
|
r'^我的立场\s*[::]\s*', # 我的立场:
|
||||||
|
r'^回应\s*[::]\s*', # 回应:
|
||||||
|
r'^辩论发言\s*[::]\s*', # 辩论发言:
|
||||||
|
]
|
||||||
|
|
||||||
|
cleaned = response.strip()
|
||||||
|
for pattern in patterns_to_remove:
|
||||||
|
cleaned = re.sub(pattern, '', cleaned, flags=re.MULTILINE)
|
||||||
|
|
||||||
|
return cleaned.strip()
|
||||||
|
|
||||||
|
async def get_session_status(self, session_id: str) -> Optional[DebateSession]:
|
||||||
|
"""
|
||||||
|
Get the current status of a debate session
|
||||||
|
"""
|
||||||
|
return await self.session_manager.get_session(self.db, session_id)
|
||||||
|
|
||||||
|
async def terminate_session(self, session_id: str):
|
||||||
|
"""
|
||||||
|
Terminate a debate session prematurely
|
||||||
|
"""
|
||||||
|
session = await self.session_manager.get_session(self.db, session_id)
|
||||||
|
if session:
|
||||||
|
session.status = "terminated"
|
||||||
|
await self.session_manager.update_session(self.db, session)
|
||||||
0
providers/__init__.py
Normal file
0
providers/__init__.py
Normal file
73
providers/base_provider.py
Normal file
73
providers/base_provider.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional, List, Dict, Any, Tuple
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
|
class LLMProvider(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for LLM providers
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, db: Session, api_key: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Initialize the provider with a database session and optional API key
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str:
|
||||||
|
"""
|
||||||
|
Generate a response from the LLM
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def supports_tools(self) -> bool:
|
||||||
|
"""
|
||||||
|
Whether this provider supports function/tool calling.
|
||||||
|
Override in subclasses that support it.
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def generate_response_with_tools(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
tools: List[Dict[str, Any]],
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
) -> Tuple[str, List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Generate a response with tool definitions available.
|
||||||
|
Returns (text_content, tool_calls) where tool_calls is a list of
|
||||||
|
dicts with keys: name, arguments (dict).
|
||||||
|
If no tools are called, tool_calls is empty and text_content has the response.
|
||||||
|
Default implementation falls back to regular generation.
|
||||||
|
"""
|
||||||
|
text = await self.generate_response(model, prompt, max_tokens)
|
||||||
|
return text, []
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderFactory:
|
||||||
|
"""
|
||||||
|
Factory class to create provider instances
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_provider(db: Session, provider_type: str, api_key: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Create a provider instance based on the type
|
||||||
|
"""
|
||||||
|
if provider_type.value == "openai":
|
||||||
|
from app.providers.openai_provider import OpenAIProvider
|
||||||
|
return OpenAIProvider(db, api_key)
|
||||||
|
elif provider_type.value == "claude":
|
||||||
|
from app.providers.claude_provider import ClaudeProvider
|
||||||
|
return ClaudeProvider(db, api_key)
|
||||||
|
elif provider_type.value == "qwen":
|
||||||
|
from app.providers.qwen_provider import QwenProvider
|
||||||
|
return QwenProvider(db, api_key)
|
||||||
|
elif provider_type.value == "deepseek":
|
||||||
|
from app.providers.deepseek_provider import DeepSeekProvider
|
||||||
|
return DeepSeekProvider(db, api_key)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported provider type: {provider_type}")
|
||||||
85
providers/claude_provider.py
Normal file
85
providers/claude_provider.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
import anthropic
|
||||||
|
from typing import Optional, List, Dict, Any, Tuple
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from providers.base_provider import LLMProvider
|
||||||
|
from services.api_key_service import ApiKeyService
|
||||||
|
|
||||||
|
SYSTEM_PROMPT = "你正在参与一场结构化辩论。请按照用户消息中的规则进行辩论,直接给出你的论点,不要重复提示词或历史记录。"
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeProvider(LLMProvider):
|
||||||
|
"""
|
||||||
|
Anthropic Claude API provider implementation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session, api_key: Optional[str] = None):
|
||||||
|
if not api_key:
|
||||||
|
api_key = ApiKeyService.get_api_key(db, "claude")
|
||||||
|
|
||||||
|
if api_key:
|
||||||
|
self.client = anthropic.AsyncAnthropic(api_key=api_key)
|
||||||
|
else:
|
||||||
|
raise ValueError("Claude API key not found in database or provided")
|
||||||
|
|
||||||
|
def supports_tools(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str:
|
||||||
|
try:
|
||||||
|
response = await self.client.messages.create(
|
||||||
|
model=model,
|
||||||
|
max_tokens=max_tokens or 500,
|
||||||
|
temperature=0.7,
|
||||||
|
system=SYSTEM_PROMPT,
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": prompt}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return response.content[0].text
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error calling Claude API: {str(e)}")
|
||||||
|
|
||||||
|
async def generate_response_with_tools(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
tools: List[Dict[str, Any]],
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
) -> Tuple[str, List[Dict[str, Any]]]:
|
||||||
|
try:
|
||||||
|
# Convert OpenAI-format tools to Anthropic format
|
||||||
|
anthropic_tools = []
|
||||||
|
for tool in tools:
|
||||||
|
func = tool.get("function", tool)
|
||||||
|
anthropic_tools.append({
|
||||||
|
"name": func["name"],
|
||||||
|
"description": func.get("description", ""),
|
||||||
|
"input_schema": func.get("parameters", func.get("input_schema", {}))
|
||||||
|
})
|
||||||
|
|
||||||
|
response = await self.client.messages.create(
|
||||||
|
model=model,
|
||||||
|
max_tokens=max_tokens or 500,
|
||||||
|
temperature=0.7,
|
||||||
|
system=SYSTEM_PROMPT,
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": prompt}
|
||||||
|
],
|
||||||
|
tools=anthropic_tools
|
||||||
|
)
|
||||||
|
|
||||||
|
text_content = ""
|
||||||
|
tool_calls = []
|
||||||
|
for block in response.content:
|
||||||
|
if block.type == "text":
|
||||||
|
text_content += block.text
|
||||||
|
elif block.type == "tool_use":
|
||||||
|
tool_calls.append({
|
||||||
|
"name": block.name,
|
||||||
|
"arguments": block.input
|
||||||
|
})
|
||||||
|
|
||||||
|
return text_content.strip(), tool_calls
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error calling Claude API with tools: {str(e)}")
|
||||||
64
providers/deepseek_provider.py
Normal file
64
providers/deepseek_provider.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from providers.base_provider import LLMProvider
|
||||||
|
from services.api_key_service import ApiKeyService
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSeekProvider(LLMProvider):
|
||||||
|
"""
|
||||||
|
DeepSeek API provider implementation using OpenAI-compatible API
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session, api_key: Optional[str] = None):
|
||||||
|
if not api_key:
|
||||||
|
api_key = ApiKeyService.get_api_key(db, "deepseek")
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("DeepSeek API key not found in database or provided")
|
||||||
|
|
||||||
|
self.client = AsyncOpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url="https://api.deepseek.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
def supports_tools(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str:
|
||||||
|
try:
|
||||||
|
is_reasoner = "reasoner" in model or "r1" in model.lower()
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
# deepseek-reasoner 不支持 system message,把指令放进 user message
|
||||||
|
if not is_reasoner:
|
||||||
|
messages.insert(0, {
|
||||||
|
"role": "system",
|
||||||
|
"content": "你正在参与一场结构化辩论。请按照用户消息中的规则进行辩论,直接给出你的论点,不要重复提示词或历史记录。"
|
||||||
|
})
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
}
|
||||||
|
# reasoner 模型不支持 max_tokens,使用 max_completion_tokens
|
||||||
|
if is_reasoner:
|
||||||
|
kwargs["max_completion_tokens"] = max_tokens or 4096
|
||||||
|
else:
|
||||||
|
kwargs["max_tokens"] = max_tokens or 500
|
||||||
|
|
||||||
|
response = await self.client.chat.completions.create(**kwargs)
|
||||||
|
|
||||||
|
message = response.choices[0].message
|
||||||
|
content = message.content or ""
|
||||||
|
|
||||||
|
# deepseek-reasoner 的主要内容可能在 reasoning_content 中
|
||||||
|
if not content.strip() and is_reasoner:
|
||||||
|
reasoning = getattr(message, "reasoning_content", None)
|
||||||
|
if reasoning:
|
||||||
|
content = reasoning
|
||||||
|
|
||||||
|
return content.strip()
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error calling DeepSeek API: {str(e)}")
|
||||||
72
providers/openai_provider.py
Normal file
72
providers/openai_provider.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import json
|
||||||
|
import openai
|
||||||
|
from typing import Optional, List, Dict, Any, Tuple
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from providers.base_provider import LLMProvider
|
||||||
|
from services.api_key_service import ApiKeyService
|
||||||
|
|
||||||
|
SYSTEM_PROMPT = "你正在参与一场结构化辩论。请按照用户消息中的规则进行辩论,直接给出你的论点,不要重复提示词或历史记录。"
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIProvider(LLMProvider):
|
||||||
|
"""
|
||||||
|
OpenAI API provider implementation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session, api_key: Optional[str] = None):
|
||||||
|
if not api_key:
|
||||||
|
api_key = ApiKeyService.get_api_key(db, "openai")
|
||||||
|
|
||||||
|
if api_key:
|
||||||
|
self.client = openai.AsyncOpenAI(api_key=api_key)
|
||||||
|
else:
|
||||||
|
raise ValueError("OpenAI API key not found in database or provided")
|
||||||
|
|
||||||
|
def supports_tools(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str:
|
||||||
|
try:
|
||||||
|
response = await self.client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": SYSTEM_PROMPT},
|
||||||
|
{"role": "user", "content": prompt}
|
||||||
|
],
|
||||||
|
max_tokens=max_tokens or 500
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content.strip()
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error calling OpenAI API: {str(e)}")
|
||||||
|
|
||||||
|
async def generate_response_with_tools(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
tools: List[Dict[str, Any]],
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
) -> Tuple[str, List[Dict[str, Any]]]:
|
||||||
|
try:
|
||||||
|
response = await self.client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": SYSTEM_PROMPT},
|
||||||
|
{"role": "user", "content": prompt}
|
||||||
|
],
|
||||||
|
tools=tools,
|
||||||
|
tool_choice="auto",
|
||||||
|
max_tokens=max_tokens or 500
|
||||||
|
)
|
||||||
|
message = response.choices[0].message
|
||||||
|
text_content = message.content or ""
|
||||||
|
tool_calls = []
|
||||||
|
if message.tool_calls:
|
||||||
|
for tc in message.tool_calls:
|
||||||
|
tool_calls.append({
|
||||||
|
"name": tc.function.name,
|
||||||
|
"arguments": json.loads(tc.function.arguments)
|
||||||
|
})
|
||||||
|
return text_content.strip(), tool_calls
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error calling OpenAI API with tools: {str(e)}")
|
||||||
49
providers/provider_factory.py
Normal file
49
providers/provider_factory.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
|
class LLMProvider(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for LLM providers
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, db: Session, api_key: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Initialize the provider with a database session and optional API key
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str:
|
||||||
|
"""
|
||||||
|
Generate a response from the LLM
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderFactory:
|
||||||
|
"""
|
||||||
|
Factory class to create provider instances
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_provider(db: Session, provider_type: str, api_key: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Create a provider instance based on the type
|
||||||
|
"""
|
||||||
|
if provider_type.value == "openai":
|
||||||
|
from providers.openai_provider import OpenAIProvider
|
||||||
|
return OpenAIProvider(db, api_key)
|
||||||
|
elif provider_type.value == "claude":
|
||||||
|
from providers.claude_provider import ClaudeProvider
|
||||||
|
return ClaudeProvider(db, api_key)
|
||||||
|
elif provider_type.value == "qwen":
|
||||||
|
from providers.qwen_provider import QwenProvider
|
||||||
|
return QwenProvider(db, api_key)
|
||||||
|
elif provider_type.value == "deepseek":
|
||||||
|
from providers.deepseek_provider import DeepSeekProvider
|
||||||
|
return DeepSeekProvider(db, api_key)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported provider type: {provider_type}")
|
||||||
75
providers/qwen_provider.py
Normal file
75
providers/qwen_provider.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import json
|
||||||
|
from typing import Optional, List, Dict, Any, Tuple
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from providers.base_provider import LLMProvider
|
||||||
|
from services.api_key_service import ApiKeyService
|
||||||
|
|
||||||
|
SYSTEM_PROMPT = "你正在参与一场结构化辩论。请按照用户消息中的规则进行辩论,直接给出你的论点,不要重复提示词或历史记录。"
|
||||||
|
|
||||||
|
|
||||||
|
class QwenProvider(LLMProvider):
|
||||||
|
"""
|
||||||
|
Qwen API provider implementation using DashScope OpenAI-compatible API
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session, api_key: Optional[str] = None):
|
||||||
|
if not api_key:
|
||||||
|
api_key = ApiKeyService.get_api_key(db, "qwen")
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("Qwen API key not found in database or provided")
|
||||||
|
|
||||||
|
self.client = AsyncOpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
def supports_tools(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str:
|
||||||
|
try:
|
||||||
|
response = await self.client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": SYSTEM_PROMPT},
|
||||||
|
{"role": "user", "content": prompt}
|
||||||
|
],
|
||||||
|
max_tokens=max_tokens or 500
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content.strip()
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error calling Qwen API: {str(e)}")
|
||||||
|
|
||||||
|
async def generate_response_with_tools(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
tools: List[Dict[str, Any]],
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
) -> Tuple[str, List[Dict[str, Any]]]:
|
||||||
|
try:
|
||||||
|
response = await self.client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": SYSTEM_PROMPT},
|
||||||
|
{"role": "user", "content": prompt}
|
||||||
|
],
|
||||||
|
tools=tools,
|
||||||
|
tool_choice="auto",
|
||||||
|
max_tokens=max_tokens or 500
|
||||||
|
)
|
||||||
|
message = response.choices[0].message
|
||||||
|
text_content = message.content or ""
|
||||||
|
tool_calls = []
|
||||||
|
if message.tool_calls:
|
||||||
|
for tc in message.tool_calls:
|
||||||
|
tool_calls.append({
|
||||||
|
"name": tc.function.name,
|
||||||
|
"arguments": json.loads(tc.function.arguments)
|
||||||
|
})
|
||||||
|
return text_content.strip(), tool_calls
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error calling Qwen API with tools: {str(e)}")
|
||||||
19
requirements.txt
Normal file
19
requirements.txt
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
fastapi==0.115.0
|
||||||
|
uvicorn[standard]==0.32.0
|
||||||
|
pydantic==2.9.2
|
||||||
|
pydantic-settings==2.6.1
|
||||||
|
sqlalchemy==2.0.35
|
||||||
|
aiosqlite==0.20.0
|
||||||
|
pymysql==1.1.1
|
||||||
|
cryptography==43.0.1
|
||||||
|
python-multipart==0.0.20
|
||||||
|
sse-starlette==2.1.3
|
||||||
|
openai==1.52.2
|
||||||
|
anthropic==0.40.0
|
||||||
|
tiktoken==0.8.0
|
||||||
|
python-dotenv==1.0.1
|
||||||
|
aiohttp==3.9.0
|
||||||
|
httpx==0.27.2
|
||||||
|
tavily-python==0.5.0
|
||||||
|
pyyaml==6.0.2
|
||||||
|
python-jose[cryptography]==3.3.0
|
||||||
0
services/__init__.py
Normal file
0
services/__init__.py
Normal file
71
services/api_key_service.py
Normal file
71
services/api_key_service.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from db_models import ApiKey
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Initialize the encryption key from environment or generate a new one
|
||||||
|
ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", Fernet.generate_key().decode())
|
||||||
|
cipher_suite = Fernet(ENCRYPTION_KEY.encode())
|
||||||
|
|
||||||
|
|
||||||
|
class ApiKeyService:
|
||||||
|
"""
|
||||||
|
Service for managing API keys in the database
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def encrypt_api_key(api_key: str) -> str:
|
||||||
|
"""
|
||||||
|
Encrypt an API key
|
||||||
|
"""
|
||||||
|
encrypted_key = cipher_suite.encrypt(api_key.encode())
|
||||||
|
return encrypted_key.decode()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def decrypt_api_key(encrypted_api_key: str) -> str:
|
||||||
|
"""
|
||||||
|
Decrypt an API key
|
||||||
|
"""
|
||||||
|
decrypted_key = cipher_suite.decrypt(encrypted_api_key.encode())
|
||||||
|
return decrypted_key.decode()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_api_key(db: Session, provider: str) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve and decrypt an API key for a provider
|
||||||
|
"""
|
||||||
|
api_key_record = db.query(ApiKey).filter(ApiKey.provider == provider).first()
|
||||||
|
if not api_key_record or not api_key_record.api_key_encrypted:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return ApiKeyService.decrypt_api_key(api_key_record.api_key_encrypted)
|
||||||
|
except InvalidToken:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_api_key(db: Session, provider: str, api_key: str) -> bool:
|
||||||
|
"""
|
||||||
|
Encrypt and store an API key for a provider
|
||||||
|
"""
|
||||||
|
encrypted_key = ApiKeyService.encrypt_api_key(api_key)
|
||||||
|
|
||||||
|
# Check if record exists
|
||||||
|
api_key_record = db.query(ApiKey).filter(ApiKey.provider == provider).first()
|
||||||
|
|
||||||
|
if api_key_record:
|
||||||
|
# Update existing record
|
||||||
|
api_key_record.api_key_encrypted = encrypted_key
|
||||||
|
else:
|
||||||
|
# Create new record
|
||||||
|
api_key_record = ApiKey(
|
||||||
|
provider=provider,
|
||||||
|
api_key_encrypted=encrypted_key
|
||||||
|
)
|
||||||
|
db.add(api_key_record)
|
||||||
|
|
||||||
|
try:
|
||||||
|
db.commit()
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
db.rollback()
|
||||||
|
return False
|
||||||
100
services/config_service.py
Normal file
100
services/config_service.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
|
||||||
|
CONFIG_PATH = Path(os.getenv("CONFIG_PATH", "/app/config/dialectica.yaml"))
|
||||||
|
|
||||||
|
# Reuse the same encryption key used for API keys
|
||||||
|
_ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", "")
|
||||||
|
_cipher = Fernet(_ENCRYPTION_KEY.encode()) if _ENCRYPTION_KEY else None
|
||||||
|
|
||||||
|
# Fields that should be encrypted in the YAML file
|
||||||
|
_SECRET_FIELDS = {"password"}
|
||||||
|
|
||||||
|
|
||||||
|
def _encrypt(value: str) -> str:
|
||||||
|
if not _cipher or not value:
|
||||||
|
return value
|
||||||
|
return "ENC:" + _cipher.encrypt(value.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _decrypt(value: str) -> str:
|
||||||
|
if not _cipher or not isinstance(value, str) or not value.startswith("ENC:"):
|
||||||
|
return value
|
||||||
|
try:
|
||||||
|
return _cipher.decrypt(value[4:].encode()).decode()
|
||||||
|
except InvalidToken:
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _encrypt_secrets(data: dict) -> dict:
|
||||||
|
"""Deep-copy dict, encrypting secret fields."""
|
||||||
|
out = {}
|
||||||
|
for k, v in data.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
out[k] = _encrypt_secrets(v)
|
||||||
|
elif k in _SECRET_FIELDS and isinstance(v, str) and not v.startswith("ENC:"):
|
||||||
|
out[k] = _encrypt(v)
|
||||||
|
else:
|
||||||
|
out[k] = v
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _decrypt_secrets(data: dict) -> dict:
|
||||||
|
"""Deep-copy dict, decrypting secret fields."""
|
||||||
|
out = {}
|
||||||
|
for k, v in data.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
out[k] = _decrypt_secrets(v)
|
||||||
|
elif k in _SECRET_FIELDS and isinstance(v, str):
|
||||||
|
out[k] = _decrypt(v)
|
||||||
|
else:
|
||||||
|
out[k] = v
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigService:
|
||||||
|
"""Read / write config/dialectica.yaml."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load() -> dict:
|
||||||
|
"""Load config, returning decrypted values. Empty dict if file missing."""
|
||||||
|
if not CONFIG_PATH.exists():
|
||||||
|
return {}
|
||||||
|
with open(CONFIG_PATH) as f:
|
||||||
|
raw = yaml.safe_load(f) or {}
|
||||||
|
return _decrypt_secrets(raw)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save(config: dict):
|
||||||
|
"""Save config, encrypting secret fields."""
|
||||||
|
CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
encrypted = _encrypt_secrets(config)
|
||||||
|
with open(CONFIG_PATH, "w") as f:
|
||||||
|
yaml.dump(encrypted, f, default_flow_style=False, allow_unicode=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_db_configured() -> bool:
|
||||||
|
config = ConfigService.load()
|
||||||
|
db = config.get("database", {})
|
||||||
|
return bool(db.get("host") and db.get("database"))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_database_url() -> str | None:
|
||||||
|
config = ConfigService.load()
|
||||||
|
db = config.get("database", {})
|
||||||
|
if not (db.get("host") and db.get("database")):
|
||||||
|
return None
|
||||||
|
user = db.get("user", "root")
|
||||||
|
password = db.get("password", "")
|
||||||
|
host = db["host"]
|
||||||
|
port = db.get("port", 3306)
|
||||||
|
database = db["database"]
|
||||||
|
return f"mysql+pymysql://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{database}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_initialized() -> bool:
|
||||||
|
return ConfigService.load().get("initialized", False)
|
||||||
104
services/search_service.py
Normal file
104
services/search_service.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from models.debate import SearchResult, SearchEvidence
|
||||||
|
|
||||||
|
|
||||||
|
class SearchService:
|
||||||
|
"""
|
||||||
|
Tavily web search wrapper for debate research.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str):
|
||||||
|
from tavily import TavilyClient
|
||||||
|
self.client = TavilyClient(api_key=api_key)
|
||||||
|
|
||||||
|
def search(self, query: str, max_results: int = 5) -> List[SearchResult]:
|
||||||
|
"""
|
||||||
|
Perform a web search using Tavily and return structured results.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = self.client.search(
|
||||||
|
query=query,
|
||||||
|
max_results=max_results,
|
||||||
|
search_depth="basic"
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
for item in response.get("results", []):
|
||||||
|
results.append(SearchResult(
|
||||||
|
title=item.get("title", ""),
|
||||||
|
url=item.get("url", ""),
|
||||||
|
snippet=item.get("content", "")[:500], # Truncate to avoid token bloat
|
||||||
|
score=item.get("score")
|
||||||
|
))
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Tavily search error: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_results_for_context(evidence: SearchEvidence) -> str:
|
||||||
|
"""
|
||||||
|
Format search results into a string suitable for injecting into the LLM context.
|
||||||
|
"""
|
||||||
|
if not evidence or not evidence.results:
|
||||||
|
return ""
|
||||||
|
lines = [f"\n[网络搜索结果] 搜索词: \"{evidence.query}\""]
|
||||||
|
for i, r in enumerate(evidence.results, 1):
|
||||||
|
lines.append(f" {i}. {r.title}")
|
||||||
|
lines.append(f" {r.snippet}")
|
||||||
|
lines.append(f" 来源: {r.url}")
|
||||||
|
lines.append("[搜索结果结束]\n")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_search_query(topic: str, last_opponent_argument: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
Generate a search query from the debate topic and the opponent's last argument.
|
||||||
|
"""
|
||||||
|
if last_opponent_argument:
|
||||||
|
# Extract key phrases from the opponent's argument (first ~100 chars)
|
||||||
|
snippet = last_opponent_argument[:100].strip()
|
||||||
|
return f"{topic} {snippet}"
|
||||||
|
return topic
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tool_definition() -> dict:
|
||||||
|
"""
|
||||||
|
Return the web_search tool definition for function calling.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "web_search",
|
||||||
|
"description": "Search the web for current information relevant to the debate topic. Use this to find facts, statistics, or recent news that support your argument.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The search query to look up"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tool_definition_anthropic() -> dict:
|
||||||
|
"""
|
||||||
|
Return the web_search tool definition in Anthropic format.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"name": "web_search",
|
||||||
|
"description": "Search the web for current information relevant to the debate topic. Use this to find facts, statistics, or recent news that support your argument.",
|
||||||
|
"input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The search query to look up"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
}
|
||||||
|
}
|
||||||
0
storage/__init__.py
Normal file
0
storage/__init__.py
Normal file
109
storage/database.py
Normal file
109
storage/database.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
from sqlalchemy import Column, Integer, String, DateTime, Text
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from datetime import datetime
|
||||||
|
import json
|
||||||
|
|
||||||
|
from models.debate import DebateSession
|
||||||
|
from exceptions import ServiceNotConfiguredError
|
||||||
|
from services.config_service import ConfigService
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
class DebateSessionDB(Base):
|
||||||
|
__tablename__ = "debate_sessions"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
session_id = Column(String(255), unique=True, index=True, nullable=False)
|
||||||
|
topic = Column(Text, nullable=False)
|
||||||
|
participants = Column(Text, nullable=False) # JSON string
|
||||||
|
constraints = Column(Text, nullable=False) # JSON string
|
||||||
|
rounds = Column(Text, nullable=False) # JSON string
|
||||||
|
status = Column(String(50), nullable=False)
|
||||||
|
created_at = Column(DateTime, nullable=False)
|
||||||
|
completed_at = Column(DateTime, nullable=True)
|
||||||
|
summary = Column(Text, nullable=True)
|
||||||
|
evidence_library = Column(Text, nullable=True) # JSON string
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Lazy engine / session factory
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
_engine = None
|
||||||
|
_SessionLocal = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_engine():
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
global _engine
|
||||||
|
if _engine is None:
|
||||||
|
db_url = ConfigService.get_database_url()
|
||||||
|
if not db_url:
|
||||||
|
raise ServiceNotConfiguredError("数据库未配置")
|
||||||
|
_engine = create_engine(db_url)
|
||||||
|
return _engine
|
||||||
|
|
||||||
|
|
||||||
|
def _get_session_factory():
|
||||||
|
global _SessionLocal
|
||||||
|
if _SessionLocal is None:
|
||||||
|
_SessionLocal = sessionmaker(
|
||||||
|
autocommit=False, autoflush=False, bind=_get_engine()
|
||||||
|
)
|
||||||
|
return _SessionLocal
|
||||||
|
|
||||||
|
|
||||||
|
def init_db():
|
||||||
|
"""Create all tables if DB is configured; silently skip otherwise."""
|
||||||
|
if not ConfigService.is_db_configured():
|
||||||
|
print("WARNING: Database not configured, skipping table creation.")
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
from db_models import Base as ApiBase
|
||||||
|
engine = _get_engine()
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
ApiBase.metadata.create_all(bind=engine)
|
||||||
|
except Exception as e:
|
||||||
|
global _engine, _SessionLocal
|
||||||
|
print(f"WARNING: Database connection failed, skipping table creation: {e}")
|
||||||
|
if _engine is not None:
|
||||||
|
_engine.dispose()
|
||||||
|
_engine = None
|
||||||
|
_SessionLocal = None
|
||||||
|
|
||||||
|
|
||||||
|
def debate_session_from_db(db_session) -> DebateSession:
|
||||||
|
"""Convert database session to Pydantic model."""
|
||||||
|
evidence_library = []
|
||||||
|
if db_session.evidence_library:
|
||||||
|
evidence_library = json.loads(db_session.evidence_library)
|
||||||
|
|
||||||
|
return DebateSession(
|
||||||
|
session_id=db_session.session_id,
|
||||||
|
topic=db_session.topic,
|
||||||
|
participants=json.loads(db_session.participants),
|
||||||
|
constraints=json.loads(db_session.constraints),
|
||||||
|
rounds=json.loads(db_session.rounds),
|
||||||
|
status=db_session.status,
|
||||||
|
created_at=db_session.created_at,
|
||||||
|
completed_at=db_session.completed_at,
|
||||||
|
summary=db_session.summary,
|
||||||
|
evidence_library=evidence_library
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def debate_session_to_db(session: DebateSession) -> DebateSessionDB:
|
||||||
|
"""Convert Pydantic model to database model."""
|
||||||
|
return DebateSessionDB(
|
||||||
|
session_id=session.session_id,
|
||||||
|
topic=session.topic,
|
||||||
|
participants=json.dumps([p.dict() for p in session.participants]),
|
||||||
|
constraints=json.dumps(session.constraints.dict()),
|
||||||
|
rounds=json.dumps([r.dict() for r in session.rounds], default=str),
|
||||||
|
status=session.status,
|
||||||
|
created_at=session.created_at,
|
||||||
|
completed_at=session.completed_at,
|
||||||
|
summary=session.summary,
|
||||||
|
evidence_library=json.dumps([e.dict() for e in session.evidence_library], default=str) if session.evidence_library else None
|
||||||
|
)
|
||||||
67
storage/session_manager.py
Normal file
67
storage/session_manager.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from datetime import datetime
|
||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from models.debate import DebateSession
|
||||||
|
from storage.database import DebateSessionDB, debate_session_from_db, debate_session_to_db
|
||||||
|
|
||||||
|
|
||||||
|
class SessionManager:
|
||||||
|
"""
|
||||||
|
Manages debate sessions in storage
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def save_session(self, db: Session, session: DebateSession):
|
||||||
|
"""
|
||||||
|
Save a debate session to the database
|
||||||
|
"""
|
||||||
|
db_session = debate_session_to_db(session)
|
||||||
|
db.add(db_session)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_session)
|
||||||
|
|
||||||
|
async def get_session(self, db: Session, session_id: str) -> Optional[DebateSession]:
|
||||||
|
"""
|
||||||
|
Retrieve a debate session from the database
|
||||||
|
"""
|
||||||
|
db_session = db.query(DebateSessionDB).filter(DebateSessionDB.session_id == session_id).first()
|
||||||
|
if not db_session:
|
||||||
|
return None
|
||||||
|
return debate_session_from_db(db_session)
|
||||||
|
|
||||||
|
async def update_session(self, db: Session, session: DebateSession):
|
||||||
|
"""
|
||||||
|
Update an existing debate session in the database
|
||||||
|
"""
|
||||||
|
db_session = db.query(DebateSessionDB).filter(DebateSessionDB.session_id == session.session_id).first()
|
||||||
|
if db_session:
|
||||||
|
db_session.topic = session.topic
|
||||||
|
db_session.participants = json.dumps([p.dict() for p in session.participants])
|
||||||
|
db_session.constraints = json.dumps(session.constraints.dict())
|
||||||
|
db_session.rounds = json.dumps([r.dict() for r in session.rounds], default=str)
|
||||||
|
db_session.status = session.status
|
||||||
|
db_session.completed_at = session.completed_at
|
||||||
|
db_session.summary = session.summary
|
||||||
|
db_session.evidence_library = json.dumps([e.dict() for e in session.evidence_library], default=str) if session.evidence_library else None
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_session)
|
||||||
|
|
||||||
|
async def list_sessions(self, db: Session):
|
||||||
|
"""
|
||||||
|
List all debate sessions
|
||||||
|
"""
|
||||||
|
db_sessions = db.query(DebateSessionDB).all()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"session_id": db_session.session_id,
|
||||||
|
"topic": db_session.topic,
|
||||||
|
"status": db_session.status,
|
||||||
|
"created_at": db_session.created_at.isoformat() if db_session.created_at else None
|
||||||
|
}
|
||||||
|
for db_session in db_sessions
|
||||||
|
]
|
||||||
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
39
utils/summarizer.py
Normal file
39
utils/summarizer.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
async def summarize_debate(session):
|
||||||
|
"""
|
||||||
|
Generate a summary of the debate session
|
||||||
|
"""
|
||||||
|
if not session.rounds:
|
||||||
|
return "No rounds were completed in this debate."
|
||||||
|
|
||||||
|
# Extract key points from each side
|
||||||
|
pro_points = []
|
||||||
|
con_points = []
|
||||||
|
|
||||||
|
for round_data in session.rounds:
|
||||||
|
if round_data.stance.value == "pro":
|
||||||
|
pro_points.append(round_data.content)
|
||||||
|
else:
|
||||||
|
con_points.append(round_data.content)
|
||||||
|
|
||||||
|
# Create a summary
|
||||||
|
summary_parts = [
|
||||||
|
f"辩论主题: {session.topic}",
|
||||||
|
"",
|
||||||
|
"正方主要观点:",
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, point in enumerate(pro_points, 1):
|
||||||
|
summary_parts.append(f"{i}. {point[:100]}...") # Truncate for brevity
|
||||||
|
|
||||||
|
summary_parts.append("")
|
||||||
|
summary_parts.append("反方主要观点:")
|
||||||
|
|
||||||
|
for i, point in enumerate(con_points, 1):
|
||||||
|
summary_parts.append(f"{i}. {point[:100]}...") # Truncate for brevity
|
||||||
|
|
||||||
|
summary_parts.append("")
|
||||||
|
summary_parts.append("总结: 本次辩论完成了 {} 轮,双方就 '{}' 主题进行了充分的讨论。".format(
|
||||||
|
len(session.rounds), session.topic
|
||||||
|
))
|
||||||
|
|
||||||
|
return "\n".join(summary_parts)
|
||||||
Reference in New Issue
Block a user