From 343a4b8d673ac0d869884f1da393d2210a6ea60f Mon Sep 17 00:00:00 2001 From: hzhang Date: Thu, 12 Feb 2026 15:45:48 +0000 Subject: [PATCH] init --- .gitignore | 1 + Dockerfile | 16 ++ api/__init__.py | 4 + api/api_keys.py | 113 +++++++++ api/debates.py | 158 ++++++++++++ api/models.py | 173 ++++++++++++++ api/setup.py | 192 +++++++++++++++ app.py | 59 +++++ db_models/__init__.py | 79 ++++++ exceptions/__init__.py | 3 + middleware/__init__.py | 0 middleware/auth.py | 114 +++++++++ middleware/config_guard.py | 32 +++ models/__init__.py | 0 models/debate.py | 92 +++++++ orchestrator/__init__.py | 0 orchestrator/debate_orchestrator.py | 358 ++++++++++++++++++++++++++++ providers/__init__.py | 0 providers/base_provider.py | 73 ++++++ providers/claude_provider.py | 85 +++++++ providers/deepseek_provider.py | 64 +++++ providers/openai_provider.py | 72 ++++++ providers/provider_factory.py | 49 ++++ providers/qwen_provider.py | 75 ++++++ requirements.txt | 19 ++ services/__init__.py | 0 services/api_key_service.py | 71 ++++++ services/config_service.py | 100 ++++++++ services/search_service.py | 104 ++++++++ storage/__init__.py | 0 storage/database.py | 109 +++++++++ storage/session_manager.py | 67 ++++++ utils/__init__.py | 0 utils/summarizer.py | 39 +++ 34 files changed, 2321 insertions(+) create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 api/__init__.py create mode 100644 api/api_keys.py create mode 100644 api/debates.py create mode 100644 api/models.py create mode 100644 api/setup.py create mode 100644 app.py create mode 100644 db_models/__init__.py create mode 100644 exceptions/__init__.py create mode 100644 middleware/__init__.py create mode 100644 middleware/auth.py create mode 100644 middleware/config_guard.py create mode 100644 models/__init__.py create mode 100644 models/debate.py create mode 100644 orchestrator/__init__.py create mode 100644 orchestrator/debate_orchestrator.py create mode 100644 providers/__init__.py create mode 100644 providers/base_provider.py create mode 100644 providers/claude_provider.py create mode 100644 providers/deepseek_provider.py create mode 100644 providers/openai_provider.py create mode 100644 providers/provider_factory.py create mode 100644 providers/qwen_provider.py create mode 100644 requirements.txt create mode 100644 services/__init__.py create mode 100644 services/api_key_service.py create mode 100644 services/config_service.py create mode 100644 services/search_service.py create mode 100644 storage/__init__.py create mode 100644 storage/database.py create mode 100644 storage/session_manager.py create mode 100644 utils/__init__.py create mode 100644 utils/summarizer.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..85e7c1d --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/.idea/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..025f2ab --- /dev/null +++ b/Dockerfile @@ -0,0 +1,16 @@ +FROM python:3.13-slim + +ENV PYTHONDONTWRITEBYTECODE 1 +ENV PYTHONUNBUFFERED 1 + +WORKDIR /app + +COPY requirements.txt ./requirements.txt + +RUN pip install --no-cache-dir -r requirements.txt + +COPY . /app/ + +EXPOSE 5001 + +CMD ["python3", "app.py"] \ No newline at end of file diff --git a/api/__init__.py b/api/__init__.py new file mode 100644 index 0000000..cdca84d --- /dev/null +++ b/api/__init__.py @@ -0,0 +1,4 @@ +from api.debates import router as debates_router +from api.api_keys import router as api_keys_router +from api.models import router as models_router +from api.setup import router as setup_router diff --git a/api/api_keys.py b/api/api_keys.py new file mode 100644 index 0000000..e325bf6 --- /dev/null +++ b/api/api_keys.py @@ -0,0 +1,113 @@ +from fastapi import APIRouter, Depends, Form, HTTPException + +from middleware.auth import require_auth +import aiohttp + +from services.api_key_service import ApiKeyService +from db_models import get_db + +router = APIRouter(tags=["api-keys"]) + + +async def validate_api_key(provider: str, api_key: str): + """ + Validate an API key by listing models from the provider. + All providers validated the same way: if we can list models, the key is valid. + """ + try: + if provider == "openai": + import openai + async with openai.AsyncOpenAI(api_key=api_key, timeout=10.0) as client: + await client.models.list() + return True + + # Claude, Qwen, DeepSeek: GET their models endpoint via aiohttp + endpoints = { + "claude": { + "url": "https://api.anthropic.com/v1/models", + "headers": { + "x-api-key": api_key, + "anthropic-version": "2023-06-01" + } + }, + "qwen": { + "url": "https://dashscope.aliyuncs.com/compatible-mode/v1/models", + "headers": {"Authorization": f"Bearer {api_key}"} + }, + "deepseek": { + "url": "https://api.deepseek.com/v1/models", + "headers": {"Authorization": f"Bearer {api_key}"} + } + } + + if provider in endpoints: + timeout = aiohttp.ClientTimeout(total=10) + async with aiohttp.ClientSession(timeout=timeout) as session: + ep = endpoints[provider] + async with session.get(ep["url"], headers=ep["headers"]) as response: + return response.status == 200 + + # Tavily: validate by calling REST API directly (no tavily package needed) + if provider == "tavily": + timeout = aiohttp.ClientTimeout(total=10) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post( + "https://api.tavily.com/search", + json={"api_key": api_key, "query": "test", "max_results": 1} + ) as response: + if response.status == 200: + data = await response.json() + return bool(data.get("results")) + return False + + return False + except Exception: + return False + + +@router.post("/api-keys/{provider}", dependencies=[Depends(require_auth)]) +async def set_api_key(provider: str, api_key: str = Form(...), db=Depends(get_db)): + """ + Set API key for a specific provider + """ + try: + success = ApiKeyService.set_api_key(db, provider, api_key) + if success: + return {"message": f"API key for {provider} updated successfully"} + else: + raise HTTPException(status_code=500, detail="Failed to update API key") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/api-keys/{provider}") +async def get_api_key(provider: str, db=Depends(get_db)): + """ + Get API key for a specific provider + """ + try: + api_key = ApiKeyService.get_api_key(db, provider) + if api_key: + return {"provider": provider, "api_key": api_key} + else: + raise HTTPException(status_code=404, detail=f"No API key found for {provider}") + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/validate-api-key/{provider}", dependencies=[Depends(require_auth)]) +async def validate_api_key_endpoint(provider: str, api_key: str = Form(...)): + """ + Validate an API key by making a test request to the provider + This endpoint is used by the frontend to validate API keys without CORS issues + """ + try: + is_valid = await validate_api_key(provider, api_key) + if is_valid: + return {"valid": True, "message": f"Valid {provider} API key"} + else: + return {"valid": False, "message": f"Invalid {provider} API key"} + except Exception as e: + return {"valid": False, "message": str(e)} diff --git a/api/debates.py b/api/debates.py new file mode 100644 index 0000000..0825d59 --- /dev/null +++ b/api/debates.py @@ -0,0 +1,158 @@ +from fastapi import APIRouter, Depends, HTTPException + +from middleware.auth import require_auth +from sse_starlette.sse import EventSourceResponse +from typing import Dict, Any +import asyncio +import json + +from orchestrator.debate_orchestrator import DebateOrchestrator +from models.debate import DebateRequest +from storage.session_manager import SessionManager +from db_models import get_db + +router = APIRouter(tags=["debates"]) + + +@router.post("/debate/create", dependencies=[Depends(require_auth)]) +async def create_debate(debate_request: DebateRequest, db=Depends(get_db)) -> Dict[str, Any]: + """ + Create a new debate session with specified parameters + """ + try: + orchestrator = DebateOrchestrator(db) + session_id = await orchestrator.create_session(debate_request) + + return { + "session_id": session_id, + "status": "created", + "message": f"Debate session {session_id} created successfully" + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/debate/{session_id}") +async def get_debate(session_id: str, db=Depends(get_db)) -> Dict[str, Any]: + """ + Get the current state of a debate session + """ + try: + orchestrator = DebateOrchestrator(db) + session = await orchestrator.get_session_status(session_id) + + if not session: + raise HTTPException(status_code=404, detail=f"Debate session {session_id} not found") + + return session.dict() + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/debate/{session_id}/start", dependencies=[Depends(require_auth)]) +async def start_debate(session_id: str, db=Depends(get_db)) -> Dict[str, Any]: + """ + Start a debate session and stream the results + """ + try: + orchestrator = DebateOrchestrator(db) + session = await orchestrator.run_debate(session_id) + + return { + "session_id": session_id, + "status": session.status, + "message": f"Debate session {session_id} completed" + } + except Exception as e: + import traceback + traceback.print_exc() + raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete("/debate/{session_id}", dependencies=[Depends(require_auth)]) +async def end_debate(session_id: str, db=Depends(get_db)) -> Dict[str, Any]: + """ + End a debate session prematurely + """ + try: + orchestrator = DebateOrchestrator(db) + await orchestrator.terminate_session(session_id) + + return {"session_id": session_id, "status": "terminated"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/debate/{session_id}/stream") +async def stream_debate(session_id: str, db=Depends(get_db)) -> EventSourceResponse: + """ + Stream debate updates in real-time using Server-Sent Events + """ + async def event_generator(): + session_manager = SessionManager() + + # Check if the session exists + session = await session_manager.get_session(db, session_id) + if not session: + yield {"event": "error", "data": json.dumps({"error": f"Session {session_id} not found"})} + return + + # Yield initial state + yield {"event": "update", "data": json.dumps({ + "session_id": session_id, + "status": session.status, + "rounds": [round.dict() for round in session.rounds], + "evidence_library": [e.dict() for e in session.evidence_library], + "message": "Debate session initialized" + }, default=str)} + + last_round_count = len(session.rounds) + # Poll until debate completes or times out (max 5 min) + for _ in range(150): # 150 × 2s = 300s timeout + await asyncio.sleep(2) + + # Reset transaction so we see commits from the run_debate request + db.commit() + + updated_session = await session_manager.get_session(db, session_id) + if not updated_session: + break + + # Only yield update when rounds actually changed + if len(updated_session.rounds) != last_round_count or updated_session.status != session.status: + last_round_count = len(updated_session.rounds) + session = updated_session + yield {"event": "update", "data": json.dumps({ + "session_id": session_id, + "status": updated_session.status, + "rounds": [round.dict() for round in updated_session.rounds], + "evidence_library": [e.dict() for e in updated_session.evidence_library], + "current_round": len(updated_session.rounds) + }, default=str)} + + if updated_session.status in ("completed", "terminated"): + yield {"event": "complete", "data": json.dumps({ + "session_id": session_id, + "status": updated_session.status, + "summary": updated_session.summary, + "rounds": [round.dict() for round in updated_session.rounds], + "evidence_library": [e.dict() for e in updated_session.evidence_library] + }, default=str)} + break + + return EventSourceResponse(event_generator()) + + +@router.get("/sessions") +async def list_sessions(db=Depends(get_db)) -> Dict[str, Any]: + """ + List all debate sessions + """ + try: + session_manager = SessionManager() + sessions = await session_manager.list_sessions(db) + return {"sessions": sessions} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/api/models.py b/api/models.py new file mode 100644 index 0000000..0c82db3 --- /dev/null +++ b/api/models.py @@ -0,0 +1,173 @@ +from fastapi import APIRouter, Depends, HTTPException +import aiohttp + +from services.api_key_service import ApiKeyService +from db_models import get_db +from api.api_keys import validate_api_key + +router = APIRouter(tags=["models"]) + + +@router.get("/models/{provider}") +async def get_available_models(provider: str, db=Depends(get_db)): + """ + Get available models for a specific provider by fetching from their API. + Falls back to a curated default list if the API key is missing or the request fails. + """ + # Curated defaults for each provider — used as fallback + default_models = { + "openai": [ + {"model_identifier": "gpt-4o", "display_name": "gpt-4o"}, + {"model_identifier": "gpt-4o-mini", "display_name": "gpt-4o-mini"}, + {"model_identifier": "gpt-4-turbo", "display_name": "gpt-4-turbo"}, + {"model_identifier": "gpt-4", "display_name": "gpt-4"}, + {"model_identifier": "o3-mini", "display_name": "o3-mini"}, + {"model_identifier": "gpt-3.5-turbo", "display_name": "gpt-3.5-turbo"} + ], + "claude": [ + {"model_identifier": "claude-opus-4-5", "display_name": "claude-opus-4-5"}, + {"model_identifier": "claude-sonnet-4-5", "display_name": "claude-sonnet-4-5"}, + {"model_identifier": "claude-3-5-sonnet-20241022", "display_name": "claude-3-5-sonnet-20241022"}, + {"model_identifier": "claude-3-5-haiku-20241022", "display_name": "claude-3-5-haiku-20241022"}, + {"model_identifier": "claude-3-opus-20240229", "display_name": "claude-3-opus-20240229"} + ], + "qwen": [ + {"model_identifier": "qwen3-max", "display_name": "qwen3-max"}, + {"model_identifier": "qwen3-plus", "display_name": "qwen3-plus"}, + {"model_identifier": "qwen3-flash", "display_name": "qwen3-flash"}, + {"model_identifier": "qwen-max", "display_name": "qwen-max"}, + {"model_identifier": "qwen-plus", "display_name": "qwen-plus"}, + {"model_identifier": "qwen-turbo", "display_name": "qwen-turbo"} + ], + "deepseek": [ + {"model_identifier": "deepseek-chat", "display_name": "deepseek-chat"}, + {"model_identifier": "deepseek-reasoner", "display_name": "deepseek-reasoner"}, + {"model_identifier": "deepseek-v3", "display_name": "deepseek-v3"}, + {"model_identifier": "deepseek-r1", "display_name": "deepseek-r1"} + ] + } + + defaults = default_models.get(provider, []) + + try: + # Retrieve and decrypt API key + decrypted_key = ApiKeyService.get_api_key(db, provider) + if not decrypted_key: + return {"provider": provider, "models": defaults} + + # ---------- OpenAI ---------- + if provider == "openai": + import openai + async with openai.AsyncOpenAI(api_key=decrypted_key, timeout=10.0) as client: + response = await client.models.list() + + # Keep only chat / reasoning models, sorted newest-first by created timestamp + chat_prefixes = ('gpt-', 'o1', 'o3', 'o4', 'chatgpt') + models = [] + seen = set() + for m in sorted(response.data, key=lambda x: x.created, reverse=True): + if m.id not in seen and m.id.startswith(chat_prefixes): + seen.add(m.id) + models.append({"model_identifier": m.id, "display_name": m.id}) + + if models: + return {"provider": provider, "models": models} + + # ---------- Claude ---------- + elif provider == "claude": + timeout = aiohttp.ClientTimeout(total=10) + async with aiohttp.ClientSession(timeout=timeout) as session: + headers = { + "x-api-key": decrypted_key, + "anthropic-version": "2023-06-01" + } + async with session.get("https://api.anthropic.com/v1/models", headers=headers) as resp: + if resp.status == 200: + data = await resp.json() + models = [] + for m in data.get("data", []): + model_id = m.get("id", "") + if model_id.startswith("claude-"): + models.append({"model_identifier": model_id, "display_name": model_id}) + if models: + return {"provider": provider, "models": models} + else: + print(f"Claude models API returned {resp.status}: {await resp.text()}") + + # ---------- Qwen ---------- + elif provider == "qwen": + timeout = aiohttp.ClientTimeout(total=10) + async with aiohttp.ClientSession(timeout=timeout) as session: + headers = {"Authorization": f"Bearer {decrypted_key}"} + async with session.get("https://dashscope.aliyuncs.com/compatible-mode/v1/models", headers=headers) as resp: + if resp.status == 200: + data = await resp.json() + exclude_keywords = [ + 'tts', 'vl', 'ocr', 'image', 'asr', '-mt-', '-mt', + 'math', 'embed', 'rerank', 'coder', 'translate', + 's2s', 'deep-search', 'omni', 'gui-' + ] + models = [] + for m in data.get("data", []): + model_id = m.get("id", "") + if not model_id: + continue + if not (model_id.startswith("qwen") or model_id.startswith("qwq")): + continue + if any(kw in model_id for kw in exclude_keywords): + continue + models.append({"model_identifier": model_id, "display_name": model_id}) + if models: + return {"provider": provider, "models": models} + else: + print(f"Qwen models API returned {resp.status}: {await resp.text()}") + + # ---------- DeepSeek ---------- + elif provider == "deepseek": + timeout = aiohttp.ClientTimeout(total=10) + async with aiohttp.ClientSession(timeout=timeout) as session: + headers = {"Authorization": f"Bearer {decrypted_key}"} + async with session.get("https://api.deepseek.com/v1/models", headers=headers) as resp: + if resp.status == 200: + data = await resp.json() + models = [] + for m in data.get("data", []): + model_id = m.get("id", "") + if model_id.startswith("deepseek"): + models.append({"model_identifier": model_id, "display_name": model_id}) + if models: + return {"provider": provider, "models": models} + else: + print(f"DeepSeek models API returned {resp.status}: {await resp.text()}") + + # API fetch succeeded but returned empty list, or unknown provider — use defaults + return {"provider": provider, "models": defaults} + + except Exception as e: + print(f"Error fetching models for {provider}: {e}") + import traceback + traceback.print_exc() + return {"provider": provider, "models": defaults} + + +@router.get("/providers") +async def get_available_providers(db=Depends(get_db)): + """ + Get all providers that have valid API keys set + """ + try: + available_providers = [] + for provider in ("openai", "claude", "qwen", "deepseek", "tavily"): + decrypted_key = ApiKeyService.get_api_key(db, provider) + if not decrypted_key: + continue + is_valid = await validate_api_key(provider, decrypted_key) + if is_valid: + available_providers.append({ + "provider": provider, + "has_valid_key": True + }) + + return {"providers": available_providers} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/api/setup.py b/api/setup.py new file mode 100644 index 0000000..43342ba --- /dev/null +++ b/api/setup.py @@ -0,0 +1,192 @@ +import os +from typing import Optional +from urllib.parse import quote_plus + +from fastapi import APIRouter, Depends, HTTPException, Request +from pydantic import BaseModel + +from services.config_service import ConfigService + +router = APIRouter(prefix="/api/setup", tags=["setup"]) + +PASSWORD_PLACEHOLDER = "********" + + +# --------------------------------------------------------------------------- +# Request / response models +# --------------------------------------------------------------------------- +class DatabaseConfig(BaseModel): + host: str + port: int = 3306 + user: str = "root" + password: str = "" + database: str = "dialectica" + + +class KeycloakConfig(BaseModel): + host: str = "" + realm: str = "" + client_id: str = "" + + +class TlsConfig(BaseModel): + cert_path: str = "" + key_path: str = "" + force_https: bool = False + + +class FullConfig(BaseModel): + database: Optional[DatabaseConfig] = None + keycloak: Optional[KeycloakConfig] = None + tls: Optional[TlsConfig] = None + + +# --------------------------------------------------------------------------- +# Access control dependency +# --------------------------------------------------------------------------- +async def setup_guard(request: Request): + """Three-phase access control for setup routes. + + 1. Not initialized → only localhost allowed + 2. ENV_MODE=dev → open + 3. ENV_MODE=prod → Keycloak admin JWT required + """ + config = ConfigService.load() + initialized = config.get("initialized", False) + env_mode = os.getenv("ENV_MODE", "dev") + + if env_mode == "dev": + return # dev mode: no auth needed, even before initialisation + + if not initialized: + # prod + not initialised: only localhost may configure + client_ip = request.client.host + if client_ip not in ("127.0.0.1", "::1"): + raise HTTPException( + status_code=403, + detail="初次设置仅允许从本机访问", + ) + return + + # prod → delegate to Keycloak middleware (Phase 3) + from app.middleware.auth import require_admin + await require_admin(request, config) + + +# --------------------------------------------------------------------------- +# Routes +# --------------------------------------------------------------------------- +@router.get("/status") +async def setup_status(): + """Return current system initialisation state, including KC info for OIDC.""" + config = ConfigService.load() + env_mode = os.getenv("ENV_MODE", "dev") + + result = { + "initialized": config.get("initialized", False), + "env_mode": env_mode, + "db_configured": ConfigService.is_db_configured(), + } + + # Include Keycloak info so the frontend can build OIDC config + kc = config.get("keycloak", {}) + if env_mode == "prod" and kc.get("host") and kc.get("realm"): + result["keycloak"] = { + "authority": f"{kc['host']}/realms/{kc['realm']}", + "client_id": kc.get("client_id", ""), + } + + return result + + +@router.get("/config", dependencies=[Depends(setup_guard)]) +async def get_config(): + """Return full config with passwords replaced by placeholder.""" + config = ConfigService.load() + if "database" in config and config["database"].get("password"): + config["database"]["password"] = PASSWORD_PLACEHOLDER + return config + + +@router.put("/config", dependencies=[Depends(setup_guard)]) +async def update_config(payload: FullConfig): + """Merge submitted config sections into the YAML file.""" + config = ConfigService.load() + + if payload.database is not None: + db_dict = payload.database.model_dump() + # If password is the placeholder, keep the existing real password + if db_dict.get("password") == PASSWORD_PLACEHOLDER: + db_dict["password"] = config.get("database", {}).get("password", "") + config["database"] = db_dict + if payload.keycloak is not None: + config["keycloak"] = payload.keycloak.model_dump() + if payload.tls is not None: + config["tls"] = payload.tls.model_dump() + + ConfigService.save(config) + return {"message": "配置已保存"} + + +@router.post("/test-db", dependencies=[Depends(setup_guard)]) +async def test_db_connection(db_config: DatabaseConfig): + """Test a database connection with the provided parameters (no save).""" + from sqlalchemy import create_engine, text + + password = db_config.password + # If password is the placeholder, use the real password from config + if password == PASSWORD_PLACEHOLDER: + password = ConfigService.load().get("database", {}).get("password", "") + + url = ( + f"mysql+pymysql://{quote_plus(db_config.user)}:{quote_plus(password)}" + f"@{db_config.host}:{db_config.port}/{db_config.database}" + ) + try: + engine = create_engine(url, pool_pre_ping=True) + with engine.connect() as conn: + conn.execute(text("SELECT 1")) + engine.dispose() + return {"success": True, "message": "数据库连接成功"} + except Exception as e: + return {"success": False, "message": str(e)} + + +@router.post("/test-keycloak", dependencies=[Depends(setup_guard)]) +async def test_keycloak(kc_config: KeycloakConfig): + """Test Keycloak connectivity by fetching the OIDC discovery document.""" + import httpx + + well_known = ( + f"{kc_config.host}/realms/{kc_config.realm}" + f"/.well-known/openid-configuration" + ) + try: + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.get(well_known) + if resp.status_code == 200: + return {"success": True, "message": "Keycloak 连通正常"} + return {"success": False, "message": f"HTTP {resp.status_code}"} + except Exception as e: + return {"success": False, "message": str(e)} + + +@router.post("/initialize", dependencies=[Depends(setup_guard)]) +async def initialize(): + """Mark system as initialised and reload DB connection.""" + config = ConfigService.load() + + if not ConfigService.is_db_configured(): + raise HTTPException(status_code=400, detail="请先配置数据库连接") + + # Reload DB engine so business routes can start working + from app.db_models import reload_db_connection + from app.storage.database import init_db + + reload_db_connection() + init_db() + + config["initialized"] = True + ConfigService.save(config) + + return {"message": "系统初始化完成", "initialized": True} diff --git a/app.py b/app.py new file mode 100644 index 0000000..a54c7af --- /dev/null +++ b/app.py @@ -0,0 +1,59 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from starlette.responses import JSONResponse +import uvicorn + +from storage.database import init_db +from exceptions import ServiceNotConfiguredError +from middleware.config_guard import ConfigGuardMiddleware +from api import debates_router, api_keys_router, models_router, setup_router + +app = FastAPI( + title="Dialectica - Multi-Model Debate Framework", + description="A framework for structured debates between multiple language models", + version="0.1.0" +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Allow all origins for development + allow_credentials=True, + allow_methods=["*"], # Allow all methods + allow_headers=["*"], # Allow all headers + # Add exposed headers to allow frontend to access response headers + expose_headers=["Access-Control-Allow-Origin", "Access-Control-Allow-Credentials"] +) + +# Config guard: return 503 for business routes when DB not configured +app.add_middleware(ConfigGuardMiddleware) + + +@app.exception_handler(ServiceNotConfiguredError) +async def not_configured_handler(request, exc): + return JSONResponse( + status_code=503, + content={"error_code": "SERVICE_NOT_CONFIGURED", "detail": str(exc)}, + ) + + +@app.on_event("startup") +def startup_event(): + """Initialize database on startup (skipped if not configured).""" + init_db() + + +# Register routers +app.include_router(debates_router) +app.include_router(api_keys_router) +app.include_router(models_router) +app.include_router(setup_router) + + +@app.get("/") +async def root(): + return {"message": "Welcome to Dialectica - Multi-Model Debate Framework"} + + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8020) diff --git a/db_models/__init__.py b/db_models/__init__.py new file mode 100644 index 0000000..c05a4b6 --- /dev/null +++ b/db_models/__init__.py @@ -0,0 +1,79 @@ +from sqlalchemy import Column, Integer, String, Boolean, DateTime, Text +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +from datetime import datetime + +from exceptions import ServiceNotConfiguredError +from services.config_service import ConfigService + +Base = declarative_base() + + +class ApiKey(Base): + __tablename__ = "api_keys" + + id = Column(Integer, primary_key=True, index=True) + provider = Column(String(50), unique=True, index=True, nullable=False) + api_key_encrypted = Column(Text, nullable=False) # Encrypted API key + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow) + + +class ModelConfig(Base): + __tablename__ = "model_configs" + + id = Column(Integer, primary_key=True, index=True) + provider = Column(String(50), nullable=False) + model_name = Column(String(100), nullable=False) + display_name = Column(String(100)) # Optional display name + is_active = Column(Boolean, default=True) + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow) + + __table_args__ = {'mysql_charset': 'utf8mb4'} + + +# --------------------------------------------------------------------------- +# Lazy engine / session factory — only created when first requested +# --------------------------------------------------------------------------- +_engine = None +_SessionLocal = None + + +def _get_engine(): + from sqlalchemy import create_engine + global _engine + if _engine is None: + db_url = ConfigService.get_database_url() + if not db_url: + raise ServiceNotConfiguredError("数据库未配置") + _engine = create_engine(db_url) + return _engine + + +def _get_session_factory(): + global _SessionLocal + if _SessionLocal is None: + _SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=_get_engine() + ) + return _SessionLocal + + +def reload_db_connection(): + """Dispose current engine and reset so next call rebuilds from config.""" + global _engine, _SessionLocal + if _engine is not None: + _engine.dispose() + _engine = None + _SessionLocal = None + + +def get_db(): + """FastAPI dependency that yields a DB session.""" + session_factory = _get_session_factory() + db = session_factory() + try: + yield db + finally: + db.close() diff --git a/exceptions/__init__.py b/exceptions/__init__.py new file mode 100644 index 0000000..4a171f7 --- /dev/null +++ b/exceptions/__init__.py @@ -0,0 +1,3 @@ +class ServiceNotConfiguredError(Exception): + """Raised when a required service (database, etc.) is not configured.""" + pass diff --git a/middleware/__init__.py b/middleware/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/middleware/auth.py b/middleware/auth.py new file mode 100644 index 0000000..e1b0cc4 --- /dev/null +++ b/middleware/auth.py @@ -0,0 +1,114 @@ +"""Keycloak JWT authentication middleware.""" + +import os +from fastapi import HTTPException, Request +from jose import jwt, JWTError, jwk +from jose.utils import base64url_decode +import httpx + +# Cache JWKS per (host, realm) to avoid fetching on every request +_jwks_cache: dict[str, dict] = {} + + +async def _get_jwks(kc_host: str, realm: str) -> dict: + cache_key = f"{kc_host}/{realm}" + if cache_key in _jwks_cache: + return _jwks_cache[cache_key] + + url = f"{kc_host}/realms/{realm}/protocol/openid-connect/certs" + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.get(url) + if resp.status_code != 200: + raise HTTPException( + status_code=502, + detail=f"无法获取 Keycloak JWKS: HTTP {resp.status_code}", + ) + data = resp.json() + _jwks_cache[cache_key] = data + return data + + +def _find_rsa_key(jwks: dict, token: str) -> dict | None: + """Find the matching RSA key from JWKS for the token's kid.""" + unverified_header = jwt.get_unverified_header(token) + kid = unverified_header.get("kid") + for key in jwks.get("keys", []): + if key.get("kid") == kid: + return key + return None + + +async def verify_token(request: Request, kc_host: str, realm: str) -> dict: + """Extract and verify the Bearer JWT from the Authorization header. + + Returns the decoded payload on success. + Raises HTTPException(401) on missing/invalid token. + """ + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + raise HTTPException(status_code=401, detail="缺少 Authorization Bearer token") + + token = auth_header[7:] + + jwks = await _get_jwks(kc_host, realm) + rsa_key = _find_rsa_key(jwks, token) + if rsa_key is None: + # Clear cache in case keys rotated + _jwks_cache.pop(f"{kc_host}/{realm}", None) + jwks = await _get_jwks(kc_host, realm) + rsa_key = _find_rsa_key(jwks, token) + if rsa_key is None: + raise HTTPException(status_code=401, detail="无法匹配 JWT 签名密钥") + + try: + payload = jwt.decode( + token, + rsa_key, + algorithms=["RS256"], + options={"verify_aud": False}, # Keycloak audience varies by client + ) + return payload + except JWTError as e: + raise HTTPException(status_code=401, detail=f"JWT 验证失败: {e}") + + +async def require_auth(request: Request): + """Verify Bearer JWT for write endpoints. + + Dev mode: passthrough (no auth required). + Prod mode: validates JWT via Keycloak JWKS. + """ + if os.getenv("ENV_MODE", "dev") == "dev": + return None + + from app.services.config_service import ConfigService + config = ConfigService.load() + kc = config.get("keycloak", {}) + if not kc.get("host"): + return None # KC not configured – allow access + + return await verify_token(request, kc["host"], kc.get("realm", "")) + + +async def require_admin(request: Request, config: dict): + """Verify the request carries a valid Keycloak JWT with admin role. + + Raises HTTPException(401/403) on failure. + """ + kc = config.get("keycloak", {}) + kc_host = kc.get("host") + realm = kc.get("realm") + + if not kc_host or not realm: + raise HTTPException( + status_code=503, + detail="Keycloak 未配置,无法进行鉴权", + ) + + payload = await verify_token(request, kc_host, realm) + + roles = payload.get("realm_access", {}).get("roles", []) + if "admin" not in roles: + raise HTTPException(status_code=403, detail="需要 admin 角色") + + return payload diff --git a/middleware/config_guard.py b/middleware/config_guard.py new file mode 100644 index 0000000..fb666d9 --- /dev/null +++ b/middleware/config_guard.py @@ -0,0 +1,32 @@ +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import JSONResponse + +from services.config_service import ConfigService + +# Paths that are always accessible, even when DB is not configured +_ALLOWED_PREFIXES = ("/api/setup", "/docs", "/openapi", "/redoc") + + +class ConfigGuardMiddleware(BaseHTTPMiddleware): + """Return 503 for all business routes when the database is not configured.""" + + async def dispatch(self, request, call_next): + path = request.url.path + + # Always allow: setup routes, root, docs, OPTIONS (CORS preflight) + if path == "/" or request.method == "OPTIONS": + return await call_next(request) + for prefix in _ALLOWED_PREFIXES: + if path.startswith(prefix): + return await call_next(request) + + if not ConfigService.is_db_configured(): + return JSONResponse( + status_code=503, + content={ + "error_code": "SERVICE_NOT_CONFIGURED", + "detail": "数据库未配置,请先完成系统初始化", + }, + ) + + return await call_next(request) diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/debate.py b/models/debate.py new file mode 100644 index 0000000..cf30ee3 --- /dev/null +++ b/models/debate.py @@ -0,0 +1,92 @@ +from pydantic import BaseModel +from typing import List, Optional +from enum import Enum +from datetime import datetime + + +class ModelProvider(str, Enum): + OPENAI = "openai" + DEEPSEEK = "deepseek" + QWEN = "qwen" + CLAUDE = "claude" + + +class DebateStance(str, Enum): + PRO = "pro" + CON = "con" + + +class SearchResult(BaseModel): + title: str + url: str + snippet: str + score: Optional[float] = None + + +class SearchEvidence(BaseModel): + query: str + results: List[SearchResult] + mode: str # "auto", "tool", "both" + + +class DebateRound(BaseModel): + round_number: int + speaker: str # Model identifier + stance: DebateStance + content: str + timestamp: datetime + token_count: Optional[int] = None + search_evidence: Optional[SearchEvidence] = None + + +class DebateParticipant(BaseModel): + model_config = {"protected_namespaces": ()} + + model_identifier: str + provider: ModelProvider + stance: DebateStance + api_key: Optional[str] = None + + +class DebateConstraints(BaseModel): + max_rounds: int = 5 + max_tokens_per_turn: int = 500 + max_total_tokens: Optional[int] = None + forbid_repetition: bool = True + must_respond_to_opponent: bool = True + web_search_enabled: bool = False + web_search_mode: str = "auto" # "auto", "tool", "both" + + +class DebateRequest(BaseModel): + topic: str + participants: List[DebateParticipant] + constraints: DebateConstraints + custom_system_prompt: Optional[str] = None + + +class EvidenceReference(BaseModel): + round_number: int + speaker: str + stance: DebateStance + + +class EvidenceEntry(BaseModel): + title: str + url: str + snippet: str + score: Optional[float] = None + references: List[EvidenceReference] + + +class DebateSession(BaseModel): + session_id: str + topic: str + participants: List[DebateParticipant] + constraints: DebateConstraints + rounds: List[DebateRound] + status: str # "active", "completed", "terminated" + created_at: datetime + completed_at: Optional[datetime] = None + summary: Optional[str] = None + evidence_library: List[EvidenceEntry] = [] \ No newline at end of file diff --git a/orchestrator/__init__.py b/orchestrator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/orchestrator/debate_orchestrator.py b/orchestrator/debate_orchestrator.py new file mode 100644 index 0000000..61ecb9c --- /dev/null +++ b/orchestrator/debate_orchestrator.py @@ -0,0 +1,358 @@ +import asyncio +import uuid +from typing import Dict, List, Optional +from datetime import datetime +from sqlalchemy.orm import Session + +from models.debate import ( + DebateRequest, DebateSession, DebateRound, + DebateStance, DebateParticipant, SearchEvidence, SearchResult, + EvidenceEntry, EvidenceReference +) +from providers.provider_factory import ProviderFactory +from storage.session_manager import SessionManager +from utils.summarizer import summarize_debate +from services.search_service import SearchService +from services.api_key_service import ApiKeyService + + +class DebateOrchestrator: + """ + Orchestrates the debate between multiple language models + """ + + def __init__(self, db: Session): + self.db = db + self.session_manager = SessionManager() + self.provider_factory = ProviderFactory() + + async def create_session(self, debate_request: DebateRequest) -> str: + """ + Create a new debate session + """ + session_id = str(uuid.uuid4()) + + session = DebateSession( + session_id=session_id, + topic=debate_request.topic, + participants=debate_request.participants, + constraints=debate_request.constraints, + rounds=[], + status="active", + created_at=datetime.now() + ) + + await self.session_manager.save_session(self.db, session) + return session_id + + def _get_search_service(self) -> Optional[SearchService]: + """ + Get a SearchService instance if Tavily API key is available. + """ + tavily_key = ApiKeyService.get_api_key(self.db, "tavily") + if tavily_key: + return SearchService(api_key=tavily_key) + return None + + async def run_debate(self, session_id: str) -> DebateSession: + """ + Run the complete debate process + """ + session = await self.session_manager.get_session(self.db, session_id) + if not session: + raise ValueError(f"Session {session_id} not found") + + # Initialize providers for each participant + providers = {} + for participant in session.participants: + provider = self.provider_factory.create_provider( + self.db, + participant.provider, + participant.api_key # This can be None, and the provider will fetch from DB + ) + providers[participant.model_identifier] = provider + + # Initialize search service if web search is enabled + search_service = None + web_search_enabled = session.constraints.web_search_enabled + web_search_mode = session.constraints.web_search_mode + if web_search_enabled: + search_service = self._get_search_service() + if not search_service: + print("Warning: Web search enabled but no Tavily API key found. Disabling search.") + web_search_enabled = False + + # Run the debate rounds + for round_num in range(session.constraints.max_rounds): + if session.status != "active": + break + + # Alternate between participants + current_participant = session.participants[round_num % len(session.participants)] + provider = providers[current_participant.model_identifier] + + # Perform automatic search if enabled + search_evidence = None + if web_search_enabled and web_search_mode in ("auto", "both"): + search_evidence = self._perform_automatic_search( + search_service, session, round_num + ) + + # Prepare context for the current turn (with search results if available) + context = self._prepare_context(session, current_participant.stance, search_evidence) + + # Determine if we should use tool calling for this round + use_tool_calling = ( + web_search_enabled + and web_search_mode in ("tool", "both") + and provider.supports_tools() + ) + + if use_tool_calling: + response, tool_evidence = await self._handle_tool_calls( + provider, current_participant.model_identifier, + context, search_service + ) + # Merge tool-based evidence with auto evidence + if tool_evidence: + if search_evidence: + search_evidence.results.extend(tool_evidence.results) + search_evidence.query += f" | {tool_evidence.query}" + search_evidence.mode = "both" + else: + search_evidence = tool_evidence + else: + response = await provider.generate_response( + model=current_participant.model_identifier, + prompt=context + ) + + # Clean the response to remove any echoed prompt/meta text + response = self._clean_response(response) + + # Create a new round + round_data = DebateRound( + round_number=round_num + 1, + speaker=current_participant.model_identifier, + stance=current_participant.stance, + content=response, + timestamp=datetime.now(), + token_count=len(response.split()), # Approximate token count + search_evidence=search_evidence + ) + + session.rounds.append(round_data) + + # Update evidence library with search results from this round + if round_data.search_evidence: + self._update_evidence_library(session, round_data) + + # Update session in storage + await self.session_manager.update_session(self.db, session) + + # Small delay between rounds to simulate realistic interaction + await asyncio.sleep(1) + + # Generate summary after all rounds are complete + summary = await summarize_debate(session) + session.summary = summary + session.status = "completed" + session.completed_at = datetime.now() + + await self.session_manager.update_session(self.db, session) + return session + + def _perform_automatic_search( + self, search_service: SearchService, session: DebateSession, round_num: int + ) -> Optional[SearchEvidence]: + """ + Perform an automatic web search based on the topic and last opponent argument. + """ + last_opponent_arg = None + if session.rounds: + last_opponent_arg = session.rounds[-1].content + + query = SearchService.generate_search_query(session.topic, last_opponent_arg) + results = search_service.search(query, max_results=3) + + if results: + return SearchEvidence( + query=query, + results=results, + mode="auto" + ) + return None + + async def _handle_tool_calls( + self, provider, model: str, context: str, search_service: SearchService + ) -> tuple: + """ + Handle tool calling flow: send prompt with tools, execute any tool calls, + then re-prompt with results. Max 2 tool call iterations. + Returns (final_response_text, SearchEvidence_or_None). + """ + tools = [SearchService.get_tool_definition()] + all_search_results = [] + all_queries = [] + + text, tool_calls = await provider.generate_response_with_tools( + model=model, + prompt=context, + tools=tools, + max_tokens=500 + ) + + # If no tool calls, return the text response directly + if not tool_calls: + return text, None + + # Process up to 2 rounds of tool calls + for iteration in range(2): + if not tool_calls: + break + + # Execute each tool call + tool_results_text = [] + for tc in tool_calls: + if tc["name"] == "web_search": + query = tc["arguments"].get("query", "") + all_queries.append(query) + results = search_service.search(query, max_results=3) + all_search_results.extend(results) + + # Format results for the model + evidence = SearchEvidence(query=query, results=results, mode="tool") + tool_results_text.append(SearchService.format_results_for_context(evidence)) + + # Re-prompt the model with tool results + augmented_context = context + "\n" + "\n".join(tool_results_text) + augmented_context += "\n请基于以上搜索结果和辩论历史,给出你的论点。" + + text, tool_calls = await provider.generate_response_with_tools( + model=model, + prompt=augmented_context, + tools=tools, + max_tokens=500 + ) + + # Build combined evidence + evidence = None + if all_search_results: + evidence = SearchEvidence( + query=" | ".join(all_queries), + results=all_search_results, + mode="tool" + ) + + # If we still got no text (model keeps calling tools), fall back + if not text: + text = await provider.generate_response(model=model, prompt=context, max_tokens=500) + + return text, evidence + + def _update_evidence_library(self, session: DebateSession, round_data: DebateRound): + """ + Merge search results from a round into the session's evidence library, deduplicating by URL. + """ + ref = EvidenceReference( + round_number=round_data.round_number, + speaker=round_data.speaker, + stance=round_data.stance + ) + + url_index = {entry.url: i for i, entry in enumerate(session.evidence_library)} + + for result in round_data.search_evidence.results: + if result.url in url_index: + entry = session.evidence_library[url_index[result.url]] + # Avoid duplicate references (same round + speaker) + if not any( + r.round_number == ref.round_number and r.speaker == ref.speaker + for r in entry.references + ): + entry.references.append(ref) + else: + new_entry = EvidenceEntry( + title=result.title, + url=result.url, + snippet=result.snippet, + score=result.score, + references=[ref] + ) + session.evidence_library.append(new_entry) + url_index[result.url] = len(session.evidence_library) - 1 + + def _prepare_context( + self, session: DebateSession, current_stance: DebateStance, + search_evidence: Optional[SearchEvidence] = None + ) -> str: + """ + Prepare the context/prompt for the current model turn + """ + # Determine the stance of the current speaker + if current_stance == DebateStance.PRO: + position_desc = "正方(支持方)" + opposing_desc = "反方(反对方)" + else: + position_desc = "反方(反对方)" + opposing_desc = "正方(支持方)" + + # Build the context with previous rounds + context_parts = [ + f"辩论主题: {session.topic}", + f"你的立场: {position_desc}", + "辩论规则:", + "- 必须回应对方上一轮的核心论点", + "- 不得重复自己已提出的观点", + "- 输出长度限制在合理范围内", + "\n历史辩论记录:" + ] + + for round_data in session.rounds: + stance_text = "正方" if round_data.stance == DebateStance.PRO else "反方" + context_parts.append(f"第{round_data.round_number}轮 - {stance_text}: {round_data.content}") + + # Inject search results if available + if search_evidence: + context_parts.append(SearchService.format_results_for_context(search_evidence)) + + context_parts.append(f"\n现在轮到你 ({position_desc}) 发言,请基于以上内容进行回应。注意:直接给出你的论点内容,不要重复上述提示词、辩论规则或历史记录。") + + return "\n".join(context_parts) + + def _clean_response(self, response: str) -> str: + """ + Clean the model response to remove any echoed prompt/meta text + """ + import re + + # Remove common prompt echoes and meta prefixes + patterns_to_remove = [ + r'^第\d+轮\s*[-::]\s*(正方|反方)\s*[::]?\s*', # 第X轮 - 正方/反方: + r'^(正方|反方)\s*[((][^))]*[))]\s*[::]?\s*', # 正方(支持方): + r'^(正方|反方)\s*[::]\s*', # 正方: or 反方: + r'^我的立场\s*[::]\s*', # 我的立场: + r'^回应\s*[::]\s*', # 回应: + r'^辩论发言\s*[::]\s*', # 辩论发言: + ] + + cleaned = response.strip() + for pattern in patterns_to_remove: + cleaned = re.sub(pattern, '', cleaned, flags=re.MULTILINE) + + return cleaned.strip() + + async def get_session_status(self, session_id: str) -> Optional[DebateSession]: + """ + Get the current status of a debate session + """ + return await self.session_manager.get_session(self.db, session_id) + + async def terminate_session(self, session_id: str): + """ + Terminate a debate session prematurely + """ + session = await self.session_manager.get_session(self.db, session_id) + if session: + session.status = "terminated" + await self.session_manager.update_session(self.db, session) diff --git a/providers/__init__.py b/providers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/providers/base_provider.py b/providers/base_provider.py new file mode 100644 index 0000000..05fcb0c --- /dev/null +++ b/providers/base_provider.py @@ -0,0 +1,73 @@ +from abc import ABC, abstractmethod +from typing import Optional, List, Dict, Any, Tuple +from sqlalchemy.orm import Session + + +class LLMProvider(ABC): + """ + Abstract base class for LLM providers + """ + + @abstractmethod + def __init__(self, db: Session, api_key: Optional[str] = None): + """ + Initialize the provider with a database session and optional API key + """ + pass + + @abstractmethod + async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str: + """ + Generate a response from the LLM + """ + pass + + def supports_tools(self) -> bool: + """ + Whether this provider supports function/tool calling. + Override in subclasses that support it. + """ + return False + + async def generate_response_with_tools( + self, + model: str, + prompt: str, + tools: List[Dict[str, Any]], + max_tokens: Optional[int] = None + ) -> Tuple[str, List[Dict[str, Any]]]: + """ + Generate a response with tool definitions available. + Returns (text_content, tool_calls) where tool_calls is a list of + dicts with keys: name, arguments (dict). + If no tools are called, tool_calls is empty and text_content has the response. + Default implementation falls back to regular generation. + """ + text = await self.generate_response(model, prompt, max_tokens) + return text, [] + + +class ProviderFactory: + """ + Factory class to create provider instances + """ + + @staticmethod + def create_provider(db: Session, provider_type: str, api_key: Optional[str] = None): + """ + Create a provider instance based on the type + """ + if provider_type.value == "openai": + from app.providers.openai_provider import OpenAIProvider + return OpenAIProvider(db, api_key) + elif provider_type.value == "claude": + from app.providers.claude_provider import ClaudeProvider + return ClaudeProvider(db, api_key) + elif provider_type.value == "qwen": + from app.providers.qwen_provider import QwenProvider + return QwenProvider(db, api_key) + elif provider_type.value == "deepseek": + from app.providers.deepseek_provider import DeepSeekProvider + return DeepSeekProvider(db, api_key) + else: + raise ValueError(f"Unsupported provider type: {provider_type}") \ No newline at end of file diff --git a/providers/claude_provider.py b/providers/claude_provider.py new file mode 100644 index 0000000..0795038 --- /dev/null +++ b/providers/claude_provider.py @@ -0,0 +1,85 @@ +import anthropic +from typing import Optional, List, Dict, Any, Tuple +from sqlalchemy.orm import Session + +from providers.base_provider import LLMProvider +from services.api_key_service import ApiKeyService + +SYSTEM_PROMPT = "你正在参与一场结构化辩论。请按照用户消息中的规则进行辩论,直接给出你的论点,不要重复提示词或历史记录。" + + +class ClaudeProvider(LLMProvider): + """ + Anthropic Claude API provider implementation + """ + + def __init__(self, db: Session, api_key: Optional[str] = None): + if not api_key: + api_key = ApiKeyService.get_api_key(db, "claude") + + if api_key: + self.client = anthropic.AsyncAnthropic(api_key=api_key) + else: + raise ValueError("Claude API key not found in database or provided") + + def supports_tools(self) -> bool: + return True + + async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str: + try: + response = await self.client.messages.create( + model=model, + max_tokens=max_tokens or 500, + temperature=0.7, + system=SYSTEM_PROMPT, + messages=[ + {"role": "user", "content": prompt} + ] + ) + return response.content[0].text + except Exception as e: + raise Exception(f"Error calling Claude API: {str(e)}") + + async def generate_response_with_tools( + self, + model: str, + prompt: str, + tools: List[Dict[str, Any]], + max_tokens: Optional[int] = None + ) -> Tuple[str, List[Dict[str, Any]]]: + try: + # Convert OpenAI-format tools to Anthropic format + anthropic_tools = [] + for tool in tools: + func = tool.get("function", tool) + anthropic_tools.append({ + "name": func["name"], + "description": func.get("description", ""), + "input_schema": func.get("parameters", func.get("input_schema", {})) + }) + + response = await self.client.messages.create( + model=model, + max_tokens=max_tokens or 500, + temperature=0.7, + system=SYSTEM_PROMPT, + messages=[ + {"role": "user", "content": prompt} + ], + tools=anthropic_tools + ) + + text_content = "" + tool_calls = [] + for block in response.content: + if block.type == "text": + text_content += block.text + elif block.type == "tool_use": + tool_calls.append({ + "name": block.name, + "arguments": block.input + }) + + return text_content.strip(), tool_calls + except Exception as e: + raise Exception(f"Error calling Claude API with tools: {str(e)}") \ No newline at end of file diff --git a/providers/deepseek_provider.py b/providers/deepseek_provider.py new file mode 100644 index 0000000..dee3972 --- /dev/null +++ b/providers/deepseek_provider.py @@ -0,0 +1,64 @@ +from typing import Optional +from sqlalchemy.orm import Session +from openai import AsyncOpenAI + +from providers.base_provider import LLMProvider +from services.api_key_service import ApiKeyService + + +class DeepSeekProvider(LLMProvider): + """ + DeepSeek API provider implementation using OpenAI-compatible API + """ + + def __init__(self, db: Session, api_key: Optional[str] = None): + if not api_key: + api_key = ApiKeyService.get_api_key(db, "deepseek") + + if not api_key: + raise ValueError("DeepSeek API key not found in database or provided") + + self.client = AsyncOpenAI( + api_key=api_key, + base_url="https://api.deepseek.com" + ) + + def supports_tools(self) -> bool: + return False + + async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str: + try: + is_reasoner = "reasoner" in model or "r1" in model.lower() + + messages = [{"role": "user", "content": prompt}] + # deepseek-reasoner 不支持 system message,把指令放进 user message + if not is_reasoner: + messages.insert(0, { + "role": "system", + "content": "你正在参与一场结构化辩论。请按照用户消息中的规则进行辩论,直接给出你的论点,不要重复提示词或历史记录。" + }) + + kwargs = { + "model": model, + "messages": messages, + } + # reasoner 模型不支持 max_tokens,使用 max_completion_tokens + if is_reasoner: + kwargs["max_completion_tokens"] = max_tokens or 4096 + else: + kwargs["max_tokens"] = max_tokens or 500 + + response = await self.client.chat.completions.create(**kwargs) + + message = response.choices[0].message + content = message.content or "" + + # deepseek-reasoner 的主要内容可能在 reasoning_content 中 + if not content.strip() and is_reasoner: + reasoning = getattr(message, "reasoning_content", None) + if reasoning: + content = reasoning + + return content.strip() + except Exception as e: + raise Exception(f"Error calling DeepSeek API: {str(e)}") \ No newline at end of file diff --git a/providers/openai_provider.py b/providers/openai_provider.py new file mode 100644 index 0000000..a7c0bb8 --- /dev/null +++ b/providers/openai_provider.py @@ -0,0 +1,72 @@ +import json +import openai +from typing import Optional, List, Dict, Any, Tuple +from sqlalchemy.orm import Session + +from providers.base_provider import LLMProvider +from services.api_key_service import ApiKeyService + +SYSTEM_PROMPT = "你正在参与一场结构化辩论。请按照用户消息中的规则进行辩论,直接给出你的论点,不要重复提示词或历史记录。" + + +class OpenAIProvider(LLMProvider): + """ + OpenAI API provider implementation + """ + + def __init__(self, db: Session, api_key: Optional[str] = None): + if not api_key: + api_key = ApiKeyService.get_api_key(db, "openai") + + if api_key: + self.client = openai.AsyncOpenAI(api_key=api_key) + else: + raise ValueError("OpenAI API key not found in database or provided") + + def supports_tools(self) -> bool: + return True + + async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str: + try: + response = await self.client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": prompt} + ], + max_tokens=max_tokens or 500 + ) + return response.choices[0].message.content.strip() + except Exception as e: + raise Exception(f"Error calling OpenAI API: {str(e)}") + + async def generate_response_with_tools( + self, + model: str, + prompt: str, + tools: List[Dict[str, Any]], + max_tokens: Optional[int] = None + ) -> Tuple[str, List[Dict[str, Any]]]: + try: + response = await self.client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": prompt} + ], + tools=tools, + tool_choice="auto", + max_tokens=max_tokens or 500 + ) + message = response.choices[0].message + text_content = message.content or "" + tool_calls = [] + if message.tool_calls: + for tc in message.tool_calls: + tool_calls.append({ + "name": tc.function.name, + "arguments": json.loads(tc.function.arguments) + }) + return text_content.strip(), tool_calls + except Exception as e: + raise Exception(f"Error calling OpenAI API with tools: {str(e)}") \ No newline at end of file diff --git a/providers/provider_factory.py b/providers/provider_factory.py new file mode 100644 index 0000000..01aa67c --- /dev/null +++ b/providers/provider_factory.py @@ -0,0 +1,49 @@ +from abc import ABC, abstractmethod +from typing import Optional +from sqlalchemy.orm import Session + + +class LLMProvider(ABC): + """ + Abstract base class for LLM providers + """ + + @abstractmethod + def __init__(self, db: Session, api_key: Optional[str] = None): + """ + Initialize the provider with a database session and optional API key + """ + pass + + @abstractmethod + async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str: + """ + Generate a response from the LLM + """ + pass + + +class ProviderFactory: + """ + Factory class to create provider instances + """ + + @staticmethod + def create_provider(db: Session, provider_type: str, api_key: Optional[str] = None): + """ + Create a provider instance based on the type + """ + if provider_type.value == "openai": + from providers.openai_provider import OpenAIProvider + return OpenAIProvider(db, api_key) + elif provider_type.value == "claude": + from providers.claude_provider import ClaudeProvider + return ClaudeProvider(db, api_key) + elif provider_type.value == "qwen": + from providers.qwen_provider import QwenProvider + return QwenProvider(db, api_key) + elif provider_type.value == "deepseek": + from providers.deepseek_provider import DeepSeekProvider + return DeepSeekProvider(db, api_key) + else: + raise ValueError(f"Unsupported provider type: {provider_type}") \ No newline at end of file diff --git a/providers/qwen_provider.py b/providers/qwen_provider.py new file mode 100644 index 0000000..9b8eb42 --- /dev/null +++ b/providers/qwen_provider.py @@ -0,0 +1,75 @@ +import json +from typing import Optional, List, Dict, Any, Tuple +from sqlalchemy.orm import Session +from openai import AsyncOpenAI + +from providers.base_provider import LLMProvider +from services.api_key_service import ApiKeyService + +SYSTEM_PROMPT = "你正在参与一场结构化辩论。请按照用户消息中的规则进行辩论,直接给出你的论点,不要重复提示词或历史记录。" + + +class QwenProvider(LLMProvider): + """ + Qwen API provider implementation using DashScope OpenAI-compatible API + """ + + def __init__(self, db: Session, api_key: Optional[str] = None): + if not api_key: + api_key = ApiKeyService.get_api_key(db, "qwen") + + if not api_key: + raise ValueError("Qwen API key not found in database or provided") + + self.client = AsyncOpenAI( + api_key=api_key, + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1" + ) + + def supports_tools(self) -> bool: + return True + + async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str: + try: + response = await self.client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": prompt} + ], + max_tokens=max_tokens or 500 + ) + return response.choices[0].message.content.strip() + except Exception as e: + raise Exception(f"Error calling Qwen API: {str(e)}") + + async def generate_response_with_tools( + self, + model: str, + prompt: str, + tools: List[Dict[str, Any]], + max_tokens: Optional[int] = None + ) -> Tuple[str, List[Dict[str, Any]]]: + try: + response = await self.client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": prompt} + ], + tools=tools, + tool_choice="auto", + max_tokens=max_tokens or 500 + ) + message = response.choices[0].message + text_content = message.content or "" + tool_calls = [] + if message.tool_calls: + for tc in message.tool_calls: + tool_calls.append({ + "name": tc.function.name, + "arguments": json.loads(tc.function.arguments) + }) + return text_content.strip(), tool_calls + except Exception as e: + raise Exception(f"Error calling Qwen API with tools: {str(e)}") diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f4a8f21 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +fastapi==0.115.0 +uvicorn[standard]==0.32.0 +pydantic==2.9.2 +pydantic-settings==2.6.1 +sqlalchemy==2.0.35 +aiosqlite==0.20.0 +pymysql==1.1.1 +cryptography==43.0.1 +python-multipart==0.0.20 +sse-starlette==2.1.3 +openai==1.52.2 +anthropic==0.40.0 +tiktoken==0.8.0 +python-dotenv==1.0.1 +aiohttp==3.9.0 +httpx==0.27.2 +tavily-python==0.5.0 +pyyaml==6.0.2 +python-jose[cryptography]==3.3.0 diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/services/api_key_service.py b/services/api_key_service.py new file mode 100644 index 0000000..4a82b67 --- /dev/null +++ b/services/api_key_service.py @@ -0,0 +1,71 @@ +from cryptography.fernet import Fernet, InvalidToken +from sqlalchemy.orm import Session +from db_models import ApiKey +import os + +# Initialize the encryption key from environment or generate a new one +ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", Fernet.generate_key().decode()) +cipher_suite = Fernet(ENCRYPTION_KEY.encode()) + + +class ApiKeyService: + """ + Service for managing API keys in the database + """ + + @staticmethod + def encrypt_api_key(api_key: str) -> str: + """ + Encrypt an API key + """ + encrypted_key = cipher_suite.encrypt(api_key.encode()) + return encrypted_key.decode() + + @staticmethod + def decrypt_api_key(encrypted_api_key: str) -> str: + """ + Decrypt an API key + """ + decrypted_key = cipher_suite.decrypt(encrypted_api_key.encode()) + return decrypted_key.decode() + + @staticmethod + def get_api_key(db: Session, provider: str) -> str: + """ + Retrieve and decrypt an API key for a provider + """ + api_key_record = db.query(ApiKey).filter(ApiKey.provider == provider).first() + if not api_key_record or not api_key_record.api_key_encrypted: + return None + try: + return ApiKeyService.decrypt_api_key(api_key_record.api_key_encrypted) + except InvalidToken: + return None + + @staticmethod + def set_api_key(db: Session, provider: str, api_key: str) -> bool: + """ + Encrypt and store an API key for a provider + """ + encrypted_key = ApiKeyService.encrypt_api_key(api_key) + + # Check if record exists + api_key_record = db.query(ApiKey).filter(ApiKey.provider == provider).first() + + if api_key_record: + # Update existing record + api_key_record.api_key_encrypted = encrypted_key + else: + # Create new record + api_key_record = ApiKey( + provider=provider, + api_key_encrypted=encrypted_key + ) + db.add(api_key_record) + + try: + db.commit() + return True + except Exception: + db.rollback() + return False \ No newline at end of file diff --git a/services/config_service.py b/services/config_service.py new file mode 100644 index 0000000..74c5e2d --- /dev/null +++ b/services/config_service.py @@ -0,0 +1,100 @@ +import os +from pathlib import Path +from urllib.parse import quote_plus + +import yaml +from cryptography.fernet import Fernet, InvalidToken + +CONFIG_PATH = Path(os.getenv("CONFIG_PATH", "/app/config/dialectica.yaml")) + +# Reuse the same encryption key used for API keys +_ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", "") +_cipher = Fernet(_ENCRYPTION_KEY.encode()) if _ENCRYPTION_KEY else None + +# Fields that should be encrypted in the YAML file +_SECRET_FIELDS = {"password"} + + +def _encrypt(value: str) -> str: + if not _cipher or not value: + return value + return "ENC:" + _cipher.encrypt(value.encode()).decode() + + +def _decrypt(value: str) -> str: + if not _cipher or not isinstance(value, str) or not value.startswith("ENC:"): + return value + try: + return _cipher.decrypt(value[4:].encode()).decode() + except InvalidToken: + return value + + +def _encrypt_secrets(data: dict) -> dict: + """Deep-copy dict, encrypting secret fields.""" + out = {} + for k, v in data.items(): + if isinstance(v, dict): + out[k] = _encrypt_secrets(v) + elif k in _SECRET_FIELDS and isinstance(v, str) and not v.startswith("ENC:"): + out[k] = _encrypt(v) + else: + out[k] = v + return out + + +def _decrypt_secrets(data: dict) -> dict: + """Deep-copy dict, decrypting secret fields.""" + out = {} + for k, v in data.items(): + if isinstance(v, dict): + out[k] = _decrypt_secrets(v) + elif k in _SECRET_FIELDS and isinstance(v, str): + out[k] = _decrypt(v) + else: + out[k] = v + return out + + +class ConfigService: + """Read / write config/dialectica.yaml.""" + + @staticmethod + def load() -> dict: + """Load config, returning decrypted values. Empty dict if file missing.""" + if not CONFIG_PATH.exists(): + return {} + with open(CONFIG_PATH) as f: + raw = yaml.safe_load(f) or {} + return _decrypt_secrets(raw) + + @staticmethod + def save(config: dict): + """Save config, encrypting secret fields.""" + CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True) + encrypted = _encrypt_secrets(config) + with open(CONFIG_PATH, "w") as f: + yaml.dump(encrypted, f, default_flow_style=False, allow_unicode=True) + + @staticmethod + def is_db_configured() -> bool: + config = ConfigService.load() + db = config.get("database", {}) + return bool(db.get("host") and db.get("database")) + + @staticmethod + def get_database_url() -> str | None: + config = ConfigService.load() + db = config.get("database", {}) + if not (db.get("host") and db.get("database")): + return None + user = db.get("user", "root") + password = db.get("password", "") + host = db["host"] + port = db.get("port", 3306) + database = db["database"] + return f"mysql+pymysql://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{database}" + + @staticmethod + def is_initialized() -> bool: + return ConfigService.load().get("initialized", False) diff --git a/services/search_service.py b/services/search_service.py new file mode 100644 index 0000000..2665128 --- /dev/null +++ b/services/search_service.py @@ -0,0 +1,104 @@ +from typing import List, Optional +from models.debate import SearchResult, SearchEvidence + + +class SearchService: + """ + Tavily web search wrapper for debate research. + """ + + def __init__(self, api_key: str): + from tavily import TavilyClient + self.client = TavilyClient(api_key=api_key) + + def search(self, query: str, max_results: int = 5) -> List[SearchResult]: + """ + Perform a web search using Tavily and return structured results. + """ + try: + response = self.client.search( + query=query, + max_results=max_results, + search_depth="basic" + ) + results = [] + for item in response.get("results", []): + results.append(SearchResult( + title=item.get("title", ""), + url=item.get("url", ""), + snippet=item.get("content", "")[:500], # Truncate to avoid token bloat + score=item.get("score") + )) + return results + except Exception as e: + print(f"Tavily search error: {e}") + return [] + + @staticmethod + def format_results_for_context(evidence: SearchEvidence) -> str: + """ + Format search results into a string suitable for injecting into the LLM context. + """ + if not evidence or not evidence.results: + return "" + lines = [f"\n[网络搜索结果] 搜索词: \"{evidence.query}\""] + for i, r in enumerate(evidence.results, 1): + lines.append(f" {i}. {r.title}") + lines.append(f" {r.snippet}") + lines.append(f" 来源: {r.url}") + lines.append("[搜索结果结束]\n") + return "\n".join(lines) + + @staticmethod + def generate_search_query(topic: str, last_opponent_argument: Optional[str] = None) -> str: + """ + Generate a search query from the debate topic and the opponent's last argument. + """ + if last_opponent_argument: + # Extract key phrases from the opponent's argument (first ~100 chars) + snippet = last_opponent_argument[:100].strip() + return f"{topic} {snippet}" + return topic + + @staticmethod + def get_tool_definition() -> dict: + """ + Return the web_search tool definition for function calling. + """ + return { + "type": "function", + "function": { + "name": "web_search", + "description": "Search the web for current information relevant to the debate topic. Use this to find facts, statistics, or recent news that support your argument.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query to look up" + } + }, + "required": ["query"] + } + } + } + + @staticmethod + def get_tool_definition_anthropic() -> dict: + """ + Return the web_search tool definition in Anthropic format. + """ + return { + "name": "web_search", + "description": "Search the web for current information relevant to the debate topic. Use this to find facts, statistics, or recent news that support your argument.", + "input_schema": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query to look up" + } + }, + "required": ["query"] + } + } diff --git a/storage/__init__.py b/storage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/storage/database.py b/storage/database.py new file mode 100644 index 0000000..fe97d0a --- /dev/null +++ b/storage/database.py @@ -0,0 +1,109 @@ +from sqlalchemy import Column, Integer, String, DateTime, Text +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +from datetime import datetime +import json + +from models.debate import DebateSession +from exceptions import ServiceNotConfiguredError +from services.config_service import ConfigService + +Base = declarative_base() + + +class DebateSessionDB(Base): + __tablename__ = "debate_sessions" + + id = Column(Integer, primary_key=True, index=True) + session_id = Column(String(255), unique=True, index=True, nullable=False) + topic = Column(Text, nullable=False) + participants = Column(Text, nullable=False) # JSON string + constraints = Column(Text, nullable=False) # JSON string + rounds = Column(Text, nullable=False) # JSON string + status = Column(String(50), nullable=False) + created_at = Column(DateTime, nullable=False) + completed_at = Column(DateTime, nullable=True) + summary = Column(Text, nullable=True) + evidence_library = Column(Text, nullable=True) # JSON string + + +# --------------------------------------------------------------------------- +# Lazy engine / session factory +# --------------------------------------------------------------------------- +_engine = None +_SessionLocal = None + + +def _get_engine(): + from sqlalchemy import create_engine + global _engine + if _engine is None: + db_url = ConfigService.get_database_url() + if not db_url: + raise ServiceNotConfiguredError("数据库未配置") + _engine = create_engine(db_url) + return _engine + + +def _get_session_factory(): + global _SessionLocal + if _SessionLocal is None: + _SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=_get_engine() + ) + return _SessionLocal + + +def init_db(): + """Create all tables if DB is configured; silently skip otherwise.""" + if not ConfigService.is_db_configured(): + print("WARNING: Database not configured, skipping table creation.") + return + try: + from db_models import Base as ApiBase + engine = _get_engine() + Base.metadata.create_all(bind=engine) + ApiBase.metadata.create_all(bind=engine) + except Exception as e: + global _engine, _SessionLocal + print(f"WARNING: Database connection failed, skipping table creation: {e}") + if _engine is not None: + _engine.dispose() + _engine = None + _SessionLocal = None + + +def debate_session_from_db(db_session) -> DebateSession: + """Convert database session to Pydantic model.""" + evidence_library = [] + if db_session.evidence_library: + evidence_library = json.loads(db_session.evidence_library) + + return DebateSession( + session_id=db_session.session_id, + topic=db_session.topic, + participants=json.loads(db_session.participants), + constraints=json.loads(db_session.constraints), + rounds=json.loads(db_session.rounds), + status=db_session.status, + created_at=db_session.created_at, + completed_at=db_session.completed_at, + summary=db_session.summary, + evidence_library=evidence_library + ) + + +def debate_session_to_db(session: DebateSession) -> DebateSessionDB: + """Convert Pydantic model to database model.""" + return DebateSessionDB( + session_id=session.session_id, + topic=session.topic, + participants=json.dumps([p.dict() for p in session.participants]), + constraints=json.dumps(session.constraints.dict()), + rounds=json.dumps([r.dict() for r in session.rounds], default=str), + status=session.status, + created_at=session.created_at, + completed_at=session.completed_at, + summary=session.summary, + evidence_library=json.dumps([e.dict() for e in session.evidence_library], default=str) if session.evidence_library else None + ) diff --git a/storage/session_manager.py b/storage/session_manager.py new file mode 100644 index 0000000..2f650f2 --- /dev/null +++ b/storage/session_manager.py @@ -0,0 +1,67 @@ +from sqlalchemy.orm import Session +from datetime import datetime +import json +from typing import Optional + +from models.debate import DebateSession +from storage.database import DebateSessionDB, debate_session_from_db, debate_session_to_db + + +class SessionManager: + """ + Manages debate sessions in storage + """ + + def __init__(self): + pass + + async def save_session(self, db: Session, session: DebateSession): + """ + Save a debate session to the database + """ + db_session = debate_session_to_db(session) + db.add(db_session) + db.commit() + db.refresh(db_session) + + async def get_session(self, db: Session, session_id: str) -> Optional[DebateSession]: + """ + Retrieve a debate session from the database + """ + db_session = db.query(DebateSessionDB).filter(DebateSessionDB.session_id == session_id).first() + if not db_session: + return None + return debate_session_from_db(db_session) + + async def update_session(self, db: Session, session: DebateSession): + """ + Update an existing debate session in the database + """ + db_session = db.query(DebateSessionDB).filter(DebateSessionDB.session_id == session.session_id).first() + if db_session: + db_session.topic = session.topic + db_session.participants = json.dumps([p.dict() for p in session.participants]) + db_session.constraints = json.dumps(session.constraints.dict()) + db_session.rounds = json.dumps([r.dict() for r in session.rounds], default=str) + db_session.status = session.status + db_session.completed_at = session.completed_at + db_session.summary = session.summary + db_session.evidence_library = json.dumps([e.dict() for e in session.evidence_library], default=str) if session.evidence_library else None + + db.commit() + db.refresh(db_session) + + async def list_sessions(self, db: Session): + """ + List all debate sessions + """ + db_sessions = db.query(DebateSessionDB).all() + return [ + { + "session_id": db_session.session_id, + "topic": db_session.topic, + "status": db_session.status, + "created_at": db_session.created_at.isoformat() if db_session.created_at else None + } + for db_session in db_sessions + ] \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/summarizer.py b/utils/summarizer.py new file mode 100644 index 0000000..7188e83 --- /dev/null +++ b/utils/summarizer.py @@ -0,0 +1,39 @@ +async def summarize_debate(session): + """ + Generate a summary of the debate session + """ + if not session.rounds: + return "No rounds were completed in this debate." + + # Extract key points from each side + pro_points = [] + con_points = [] + + for round_data in session.rounds: + if round_data.stance.value == "pro": + pro_points.append(round_data.content) + else: + con_points.append(round_data.content) + + # Create a summary + summary_parts = [ + f"辩论主题: {session.topic}", + "", + "正方主要观点:", + ] + + for i, point in enumerate(pro_points, 1): + summary_parts.append(f"{i}. {point[:100]}...") # Truncate for brevity + + summary_parts.append("") + summary_parts.append("反方主要观点:") + + for i, point in enumerate(con_points, 1): + summary_parts.append(f"{i}. {point[:100]}...") # Truncate for brevity + + summary_parts.append("") + summary_parts.append("总结: 本次辩论完成了 {} 轮,双方就 '{}' 主题进行了充分的讨论。".format( + len(session.rounds), session.topic + )) + + return "\n".join(summary_parts) \ No newline at end of file