159 lines
5.6 KiB
Python
159 lines
5.6 KiB
Python
from fastapi import APIRouter, Depends, HTTPException
|
||
|
||
from middleware.auth import require_auth
|
||
from sse_starlette.sse import EventSourceResponse
|
||
from typing import Dict, Any
|
||
import asyncio
|
||
import json
|
||
|
||
from orchestrator.debate_orchestrator import DebateOrchestrator
|
||
from models.debate import DebateRequest
|
||
from storage.session_manager import SessionManager
|
||
from db_models import get_db
|
||
|
||
# Router for the debate endpoints defined below; every route it registers
# carries the "debates" tag.
router = APIRouter(
    tags=["debates"],
)
|
||
|
||
|
||
@router.post("/debate/create", dependencies=[Depends(require_auth)])
async def create_debate(debate_request: DebateRequest, db=Depends(get_db)) -> Dict[str, Any]:
    """
    Create a new debate session with specified parameters.

    Requires authentication through the ``require_auth`` dependency and
    delegates creation to ``DebateOrchestrator.create_session``.

    Args:
        debate_request: Request body describing the debate to create.
        db: Database handle supplied by the ``get_db`` dependency.

    Returns:
        A dict with the new ``session_id``, ``status`` ("created") and a
        human-readable ``message``.

    Raises:
        HTTPException: re-raised unchanged if the orchestrator raises one;
            any other exception is converted into a 500 whose ``detail`` is
            the exception text.
    """
    try:
        orchestrator = DebateOrchestrator(db)
        session_id = await orchestrator.create_session(debate_request)
    except HTTPException:
        # Keep deliberate HTTP errors intact instead of masking them as 500s
        # (same policy as get_debate).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "session_id": session_id,
        "status": "created",
        "message": f"Debate session {session_id} created successfully",
    }
|
||
|
||
|
||
@router.get("/debate/{session_id}")
async def get_debate(session_id: str, db=Depends(get_db)) -> Dict[str, Any]:
    """
    Return the current state of one debate session.

    No ``require_auth`` dependency is attached to this route.

    Args:
        session_id: Identifier of the session to look up.
        db: Database handle supplied by the ``get_db`` dependency.

    Returns:
        The session object returned by
        ``DebateOrchestrator.get_session_status``, converted with ``.dict()``.

    Raises:
        HTTPException: 404 when the orchestrator returns a falsy result;
            any other HTTPException is passed through untouched; every other
            exception becomes a 500 carrying the exception text.
    """
    try:
        state = await DebateOrchestrator(db).get_session_status(session_id)
        if state:
            return state.dict()
        raise HTTPException(status_code=404, detail=f"Debate session {session_id} not found")
    except HTTPException:
        # Already an HTTP-level error (including the 404 above): let it through.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.post("/debate/{session_id}/start", dependencies=[Depends(require_auth)])
async def start_debate(session_id: str, db=Depends(get_db)) -> Dict[str, Any]:
    """
    Run a debate session to completion and report its final status.

    This endpoint does not stream: it awaits
    ``DebateOrchestrator.run_debate`` and responds only after that call
    returns. Live progress is available from ``GET /debate/{session_id}/stream``,
    which polls the database while this request is running.

    Requires authentication through the ``require_auth`` dependency.

    Args:
        session_id: Identifier of the session to run.
        db: Database handle supplied by the ``get_db`` dependency.

    Returns:
        A dict with ``session_id``, the session's final ``status`` and a
        ``message``.

    Raises:
        HTTPException: re-raised unchanged if one is raised while running;
            any other exception is logged with its traceback and converted
            into a 500 whose ``detail`` is the exception text.
    """
    try:
        orchestrator = DebateOrchestrator(db)
        session = await orchestrator.run_debate(session_id)

        # NOTE(review): if run_debate can return None for an unknown session,
        # the attribute access below fails and surfaces as a 500 — confirm.
        return {
            "session_id": session_id,
            "status": session.status,
            "message": f"Debate session {session_id} completed",
        }
    except HTTPException:
        # Keep deliberate HTTP errors intact instead of masking them as 500s
        # (same policy as get_debate).
        raise
    except Exception as e:
        import logging

        # Route the traceback through logging (respecting the app's logging
        # config) and record which session failed.
        logging.getLogger(__name__).exception("Failed to run debate session %s", session_id)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||
|
||
|
||
@router.delete("/debate/{session_id}", dependencies=[Depends(require_auth)])
async def end_debate(session_id: str, db=Depends(get_db)) -> Dict[str, Any]:
    """
    End a debate session prematurely.

    Requires authentication through the ``require_auth`` dependency and
    delegates to ``DebateOrchestrator.terminate_session``.

    Args:
        session_id: Identifier of the session to terminate.
        db: Database handle supplied by the ``get_db`` dependency.

    Returns:
        ``{"session_id": session_id, "status": "terminated"}`` once
        ``terminate_session`` returns without raising.

    Raises:
        HTTPException: re-raised unchanged if the orchestrator raises one;
            any other exception is converted into a 500 whose ``detail`` is
            the exception text.
    """
    try:
        orchestrator = DebateOrchestrator(db)
        await orchestrator.terminate_session(session_id)
    except HTTPException:
        # Keep deliberate HTTP errors intact instead of masking them as 500s
        # (same policy as get_debate).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"session_id": session_id, "status": "terminated"}
|
||
|
||
|
||
@router.get("/debate/{session_id}/stream")
async def stream_debate(session_id: str, db=Depends(get_db)) -> EventSourceResponse:
    """
    Stream debate updates in real time using Server-Sent Events.

    Event types emitted:

    * ``update``: sent once with the initial state (plus ``message``), then
      whenever the number of rounds or the status changes (plus
      ``current_round``).
    * ``complete``: sent once the status is "completed" or "terminated"
      (plus ``summary``); the stream then ends. This also covers a session
      that is already finished when the client connects.
    * ``error``: the session was not found, either at start or on a later
      poll; the stream then ends.

    The database is polled every 2 s, at most 150 times (~5 minutes). If the
    debate has not finished by then, the stream ends without a ``complete``
    event.

    Args:
        session_id: Identifier of the session to follow.
        db: Database handle supplied by the ``get_db`` dependency.

    Returns:
        An ``EventSourceResponse`` wrapping the event generator.
    """
    poll_interval_s = 2
    max_polls = 150  # 150 x 2 s = 300 s timeout
    terminal_statuses = ("completed", "terminated")

    # NOTE(review): `db` is used for the whole lifetime of the stream. Depending
    # on the FastAPI version, the teardown of a yield-based `get_db` may run
    # before a streaming response finishes. Confirm the session stays open here.

    def encode(session, **extra: Any) -> str:
        """JSON-encode the fields shared by every event, plus event-specific extras."""
        payload = {
            "session_id": session_id,
            "status": session.status,
            "rounds": [debate_round.dict() for debate_round in session.rounds],
            "evidence_library": [item.dict() for item in session.evidence_library],
        }
        payload.update(extra)
        return json.dumps(payload, default=str)

    def not_found_event() -> Dict[str, str]:
        return {"event": "error", "data": json.dumps({"error": f"Session {session_id} not found"})}

    def complete_event(session) -> Dict[str, str]:
        return {"event": "complete", "data": encode(session, summary=session.summary)}

    async def event_generator():
        session_manager = SessionManager()

        session = await session_manager.get_session(db, session_id)
        if not session:
            yield not_found_event()
            return

        yield {"event": "update", "data": encode(session, message="Debate session initialized")}

        # An already-finished session will never change again. Without this
        # check, the change-driven loop below would poll until the timeout and
        # never emit "complete".
        if session.status in terminal_statuses:
            yield complete_event(session)
            return

        last_round_count = len(session.rounds)
        for _ in range(max_polls):
            await asyncio.sleep(poll_interval_s)

            # Reset transaction so we see commits from the run_debate request.
            # NOTE(review): this is a synchronous call inside async code; if it
            # blocks, it stalls the event loop while it runs.
            db.commit()

            updated = await session_manager.get_session(db, session_id)
            if not updated:
                # The session can no longer be found. Tell the client why the
                # stream stops instead of just ending it.
                yield not_found_event()
                return

            # Only push an update when rounds or status actually changed.
            if len(updated.rounds) != last_round_count or updated.status != session.status:
                last_round_count = len(updated.rounds)
                session = updated
                yield {"event": "update", "data": encode(updated, current_round=last_round_count)}

            if updated.status in terminal_statuses:
                yield complete_event(updated)
                return

    return EventSourceResponse(event_generator())
|
||
|
||
|
||
@router.get("/sessions")
async def list_sessions(db=Depends(get_db)) -> Dict[str, Any]:
    """
    Return every debate session known to the session store.

    Args:
        db: Database handle supplied by the ``get_db`` dependency.

    Returns:
        ``{"sessions": ...}``, holding whatever ``SessionManager.list_sessions``
        returned.

    Raises:
        HTTPException: 500 carrying the exception text if building the
            manager or listing sessions fails.
    """
    # NOTE(review): unlike the create/start/delete routes, this one has no
    # require_auth dependency, so any caller can list sessions. Confirm this
    # is intended.
    try:
        all_sessions = await SessionManager().list_sessions(db)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"sessions": all_sessions}
|