Files
Dialectic.Backend/orchestrator/debate_orchestrator.py
2026-02-12 15:45:48 +00:00

359 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import uuid
from typing import Dict, List, Optional
from datetime import datetime
from sqlalchemy.orm import Session
from models.debate import (
DebateRequest, DebateSession, DebateRound,
DebateStance, DebateParticipant, SearchEvidence, SearchResult,
EvidenceEntry, EvidenceReference
)
from providers.provider_factory import ProviderFactory
from storage.session_manager import SessionManager
from utils.summarizer import summarize_debate
from services.search_service import SearchService
from services.api_key_service import ApiKeyService
class DebateOrchestrator:
    """
    Orchestrates the debate between multiple language models.

    Creates debate sessions, runs their rounds by alternating between the
    session's participants (optionally feeding web search evidence into the
    prompts), and stores rounds, the evidence library and the final summary
    through the session manager.
    """
def __init__(self, db: Session):
    """
    Bind the orchestrator to a database session and build its collaborators.

    Args:
        db: SQLAlchemy session that the orchestrator passes along to the
            session manager, the provider factory and the API key service.
    """
    # Collaborators are constructed first, in the same order as before,
    # and only then attached to the instance.
    session_manager = SessionManager()
    provider_factory = ProviderFactory()

    self.db = db
    self.session_manager = session_manager
    self.provider_factory = provider_factory
async def create_session(self, debate_request: DebateRequest) -> str:
    """
    Register a fresh debate session built from the incoming request.

    Args:
        debate_request: Request carrying the topic, participants and
            constraints of the debate.

    Returns:
        The generated UUID4 string that identifies the stored session.
    """
    new_id = str(uuid.uuid4())
    # A new session starts active, with no rounds played yet.
    session_fields = dict(
        session_id=new_id,
        topic=debate_request.topic,
        participants=debate_request.participants,
        constraints=debate_request.constraints,
        rounds=[],
        status="active",
        created_at=datetime.now(),
    )
    await self.session_manager.save_session(self.db, DebateSession(**session_fields))
    return new_id
def _get_search_service(self) -> Optional[SearchService]:
    """
    Build a SearchService from the stored Tavily API key.

    Returns:
        A SearchService configured with the key, or None when the API key
        service yields no (or an empty) key for "tavily".
    """
    api_key = ApiKeyService.get_api_key(self.db, "tavily")
    return SearchService(api_key=api_key) if api_key else None
async def run_debate(self, session_id: str) -> DebateSession:
    """
    Run the complete debate process for a stored session.

    Participants take turns in round-robin order for at most
    ``constraints.max_rounds`` rounds. Depending on the session's web search
    constraints, a turn may be augmented with automatic search results
    and/or model-driven tool calls. Every round is persisted as soon as it
    is produced, and a summary is generated once the loop ends.

    Args:
        session_id: Identifier of a session previously stored through the
            session manager.

    Returns:
        The updated session, including its rounds and summary.

    Raises:
        ValueError: If the session does not exist or has no participants.
    """
    session = await self.session_manager.get_session(self.db, session_id)
    if not session:
        raise ValueError(f"Session {session_id} not found")
    if not session.participants:
        # Without this guard the round-robin modulo below would fail with
        # an opaque ZeroDivisionError.
        raise ValueError(f"Session {session_id} has no participants")

    # Initialize one provider per participant, kept in participant order.
    # Indexing by position (instead of by model identifier) keeps two
    # participants that share a model identifier from overwriting each other.
    providers = [
        self.provider_factory.create_provider(
            self.db,
            participant.provider,
            participant.api_key,  # This can be None, and the provider will fetch from DB
        )
        for participant in session.participants
    ]

    # Initialize search service if web search is enabled
    search_service = None
    web_search_enabled = session.constraints.web_search_enabled
    web_search_mode = session.constraints.web_search_mode
    if web_search_enabled:
        search_service = self._get_search_service()
        if not search_service:
            print("Warning: Web search enabled but no Tavily API key found. Disabling search.")
            web_search_enabled = False

    # Run the debate rounds
    participant_count = len(session.participants)
    for round_num in range(session.constraints.max_rounds):
        if session.status != "active":
            break

        # Alternate between participants
        turn = round_num % participant_count
        current_participant = session.participants[turn]
        provider = providers[turn]

        # Perform automatic search if enabled
        # NOTE(review): this is a synchronous call inside a coroutine; if the
        # search service does network I/O it blocks the event loop meanwhile.
        search_evidence = None
        if web_search_enabled and web_search_mode in ("auto", "both"):
            search_evidence = self._perform_automatic_search(
                search_service, session, round_num
            )

        # Prepare context for the current turn (with search results if available)
        context = self._prepare_context(session, current_participant.stance, search_evidence)

        # Determine if we should use tool calling for this round
        use_tool_calling = (
            web_search_enabled
            and web_search_mode in ("tool", "both")
            and provider.supports_tools()
        )
        if use_tool_calling:
            response, tool_evidence = await self._handle_tool_calls(
                provider, current_participant.model_identifier,
                context, search_service
            )
            # Merge tool-based evidence with auto evidence
            if tool_evidence:
                if search_evidence:
                    search_evidence.results.extend(tool_evidence.results)
                    search_evidence.query += f" | {tool_evidence.query}"
                    search_evidence.mode = "both"
                else:
                    search_evidence = tool_evidence
        else:
            response = await provider.generate_response(
                model=current_participant.model_identifier,
                prompt=context
            )

        # Clean the response to remove any echoed prompt/meta text.
        # A provider that returns None is treated as an empty answer
        # instead of crashing the whole debate.
        response = self._clean_response(response or "")

        # Create a new round
        round_data = DebateRound(
            round_number=round_num + 1,
            speaker=current_participant.model_identifier,
            stance=current_participant.stance,
            content=response,
            timestamp=datetime.now(),
            # Approximate token count; whitespace splitting undercounts
            # text written without spaces (e.g. Chinese).
            token_count=len(response.split()),
            search_evidence=search_evidence
        )
        session.rounds.append(round_data)

        # Update evidence library with search results from this round
        if round_data.search_evidence:
            self._update_evidence_library(session, round_data)

        # Update session in storage
        await self.session_manager.update_session(self.db, session)

        # Small delay between rounds to simulate realistic interaction
        await asyncio.sleep(1)

    # Generate summary after all rounds are complete
    session.summary = await summarize_debate(session)

    # Only a still-active session is promoted to "completed". A status that
    # stopped the loop early (e.g. "terminated") must not be overwritten.
    if session.status == "active":
        session.status = "completed"
        session.completed_at = datetime.now()

    await self.session_manager.update_session(self.db, session)
    return session
def _perform_automatic_search(
    self, search_service: SearchService, session: DebateSession, round_num: int
) -> Optional[SearchEvidence]:
    """
    Run one automatic web search for the upcoming turn.

    The query is generated from the session topic and the content of the
    most recent round (None when no round has been played yet).

    Args:
        search_service: Service used to execute the search.
        session: Debate session supplying the topic and the round history.
        round_num: Index of the upcoming round; currently unused here.

    Returns:
        A SearchEvidence in "auto" mode wrapping the results, or None when
        the search yields nothing.
    """
    previous_argument = session.rounds[-1].content if session.rounds else None
    search_query = SearchService.generate_search_query(session.topic, previous_argument)
    hits = search_service.search(search_query, max_results=3)
    if not hits:
        return None
    return SearchEvidence(query=search_query, results=hits, mode="auto")
async def _handle_tool_calls(
    self, provider, model: str, context: str, search_service: SearchService,
    *, max_iterations: int = 2, max_tokens: int = 500
) -> tuple:
    """
    Handle the tool calling flow for a single turn.

    The prompt is sent together with the web search tool definition. While
    the model keeps requesting tool calls (at most ``max_iterations``
    times), every ``web_search`` call is executed and the model is
    re-prompted with all results gathered so far.

    Args:
        provider: Provider exposing ``generate_response_with_tools`` and
            ``generate_response``.
        model: Model identifier forwarded to the provider.
        context: Base prompt for the current turn.
        search_service: Service used to execute ``web_search`` tool calls.
        max_iterations: Maximum number of tool-call rounds to process.
        max_tokens: Token limit forwarded to every provider call.

    Returns:
        A ``(final_response_text, evidence)`` tuple, where ``evidence`` is a
        SearchEvidence in "tool" mode combining all executed searches, or
        None when no search produced results.
    """
    tools = [SearchService.get_tool_definition()]
    all_search_results = []
    all_queries = []
    # Accumulated across iterations so that a later re-prompt still contains
    # the results of earlier searches.
    tool_results_text = []
    prompt = context

    text, tool_calls = await provider.generate_response_with_tools(
        model=model,
        prompt=prompt,
        tools=tools,
        max_tokens=max_tokens
    )

    # If no tool calls, return the text response directly
    if not tool_calls:
        return text, None

    for _ in range(max_iterations):
        if not tool_calls:
            break

        # Execute each tool call; tools other than web_search are ignored.
        for tc in tool_calls:
            if tc["name"] != "web_search":
                continue
            query = (tc["arguments"].get("query") or "").strip()
            if not query:
                # Nothing meaningful to search for.
                continue
            all_queries.append(query)
            # Treat a None result like an empty result list.
            results = search_service.search(query, max_results=3) or []
            all_search_results.extend(results)
            # Format results for the model
            evidence = SearchEvidence(query=query, results=results, mode="tool")
            tool_results_text.append(SearchService.format_results_for_context(evidence))

        # Re-prompt the model with every tool result gathered so far
        prompt = context + "\n" + "\n".join(tool_results_text)
        prompt += "\n请基于以上搜索结果和辩论历史,给出你的论点。"
        text, tool_calls = await provider.generate_response_with_tools(
            model=model,
            prompt=prompt,
            tools=tools,
            max_tokens=max_tokens
        )

    # Build combined evidence
    evidence = None
    if all_search_results:
        evidence = SearchEvidence(
            query=" | ".join(all_queries),
            results=all_search_results,
            mode="tool"
        )

    # If we still got no text (model keeps calling tools), fall back to a
    # plain completion. The augmented prompt is reused so the search results
    # already gathered are not thrown away.
    if not text:
        text = await provider.generate_response(
            model=model, prompt=prompt, max_tokens=max_tokens
        )
    return text, evidence
def _update_evidence_library(self, session: DebateSession, round_data: DebateRound):
    """
    Fold a round's search results into the session-wide evidence library.

    Entries are deduplicated by URL: a result whose URL is already in the
    library only gains a reference to this round (unless the same round and
    speaker are referenced already); any other result becomes a new entry.

    Args:
        session: Session whose ``evidence_library`` is updated in place.
        round_data: Round whose ``search_evidence.results`` are merged.
    """
    citation = EvidenceReference(
        round_number=round_data.round_number,
        speaker=round_data.speaker,
        stance=round_data.stance,
    )
    library = session.evidence_library
    entries_by_url = {entry.url: entry for entry in library}

    for hit in round_data.search_evidence.results:
        if hit.url not in entries_by_url:
            fresh_entry = EvidenceEntry(
                title=hit.title,
                url=hit.url,
                snippet=hit.snippet,
                score=hit.score,
                references=[citation],
            )
            library.append(fresh_entry)
            entries_by_url[hit.url] = fresh_entry
            continue

        known_entry = entries_by_url[hit.url]
        # Avoid duplicate references (same round + speaker)
        already_cited = any(
            existing.round_number == citation.round_number
            and existing.speaker == citation.speaker
            for existing in known_entry.references
        )
        if not already_cited:
            known_entry.references.append(citation)
def _prepare_context(
    self, session: DebateSession, current_stance: DebateStance,
    search_evidence: Optional[SearchEvidence] = None
) -> str:
    """
    Prepare the context/prompt for the current model turn.

    The prompt consists of the topic, the speaker's position, the debate
    rules, the full round history, optional formatted search results and a
    closing instruction asking for the argument only.

    Args:
        session: Debate session supplying the topic and previous rounds.
        current_stance: Stance of the participant about to speak.
        search_evidence: Optional search results to inject before the
            closing instruction.

    Returns:
        The prompt as a single newline-joined string.
    """
    # Any stance other than PRO is presented as the opposing side.
    if current_stance == DebateStance.PRO:
        position_desc = "正方(支持方)"
    else:
        position_desc = "反方(反对方)"

    # Build the context with previous rounds
    context_parts = [
        f"辩论主题: {session.topic}",
        f"你的立场: {position_desc}",
        "辩论规则:",
        "- 必须回应对方上一轮的核心论点",
        "- 不得重复自己已提出的观点",
        "- 输出长度限制在合理范围内",
        "\n历史辩论记录:"
    ]
    for round_data in session.rounds:
        stance_text = "正方" if round_data.stance == DebateStance.PRO else "反方"
        # "第N轮" is the well-formed round label, and it is the prefix that
        # the response cleaner strips when a model echoes the history.
        context_parts.append(
            f"第{round_data.round_number}轮 - {stance_text}: {round_data.content}"
        )

    # Inject search results if available
    if search_evidence:
        context_parts.append(SearchService.format_results_for_context(search_evidence))

    context_parts.append(f"\n现在轮到你 ({position_desc}) 发言,请基于以上内容进行回应。注意:直接给出你的论点内容,不要重复上述提示词、辩论规则或历史记录。")
    return "\n".join(context_parts)
def _clean_response(self, response: str) -> str:
"""
Clean the model response to remove any echoed prompt/meta text
"""
import re
# Remove common prompt echoes and meta prefixes
patterns_to_remove = [
r'^第\d+轮\s*[-:]\s*(正方|反方)\s*[:]?\s*', # 第X轮 - 正方/反方:
r'^(正方|反方)\s*[(][^)]*[)]\s*[:]?\s*', # 正方(支持方):
r'^(正方|反方)\s*[:]\s*', # 正方: or 反方:
r'^我的立场\s*[:]\s*', # 我的立场:
r'^回应\s*[:]\s*', # 回应:
r'^辩论发言\s*[:]\s*', # 辩论发言:
]
cleaned = response.strip()
for pattern in patterns_to_remove:
cleaned = re.sub(pattern, '', cleaned, flags=re.MULTILINE)
return cleaned.strip()
async def get_session_status(self, session_id: str) -> Optional[DebateSession]:
    """
    Look up the stored state of a debate session.

    Args:
        session_id: Identifier of the session to fetch.

    Returns:
        Whatever the session manager returns for ``session_id``.
    """
    stored_session = await self.session_manager.get_session(self.db, session_id)
    return stored_session
async def terminate_session(self, session_id: str):
    """
    Terminate a debate session prematurely.

    Marks the stored session as "terminated" and persists the change.
    A session id that cannot be found is ignored silently.

    Args:
        session_id: Identifier of the session to terminate.
    """
    session = await self.session_manager.get_session(self.db, session_id)
    if not session:
        return
    session.status = "terminated"
    await self.session_manager.update_session(self.db, session)