# Debate orchestration service — coordinates multi-model debate sessions.
import asyncio
import re
import uuid
from datetime import datetime
from typing import Dict, List, Optional

from sqlalchemy.orm import Session

from models.debate import (
    DebateRequest, DebateSession, DebateRound,
    DebateStance, DebateParticipant, SearchEvidence, SearchResult,
    EvidenceEntry, EvidenceReference
)
from providers.provider_factory import ProviderFactory
from services.api_key_service import ApiKeyService
from services.search_service import SearchService
from storage.session_manager import SessionManager
from utils.summarizer import summarize_debate


class DebateOrchestrator:
    """
    Orchestrates a turn-based debate between multiple language models.

    Responsibilities:
      - create and persist debate sessions,
      - run alternating rounds between the configured participants,
      - optionally augment turns with web-search evidence (automatic
        searches and/or model-driven tool calls),
      - maintain a deduplicated evidence library on the session,
      - summarize the debate once all rounds complete.
    """

    # Patterns that strip echoed prompt/meta prefixes from model output
    # (e.g. "第3轮 - 正方:", "反方:", "回应:").  Compiled once at class
    # creation instead of being re-imported and re-compiled on every
    # _clean_response() call.
    _CLEAN_PATTERNS = tuple(
        re.compile(p, flags=re.MULTILINE)
        for p in (
            r'^第\d+轮\s*[-::]\s*(正方|反方)\s*[::]?\s*',  # 第X轮 - 正方/反方:
            r'^(正方|反方)\s*[((][^))]*[))]\s*[::]?\s*',  # 正方(支持方):
            r'^(正方|反方)\s*[::]\s*',                      # 正方: or 反方:
            r'^我的立场\s*[::]\s*',                          # 我的立场:
            r'^回应\s*[::]\s*',                              # 回应:
            r'^辩论发言\s*[::]\s*',                          # 辩论发言:
        )
    )

    def __init__(self, db: Session):
        """
        Args:
            db: SQLAlchemy session used for persistence and API-key lookups.
        """
        self.db = db
        self.session_manager = SessionManager()
        self.provider_factory = ProviderFactory()

    async def create_session(self, debate_request: DebateRequest) -> str:
        """
        Create and persist a new debate session in "active" status.

        Args:
            debate_request: Topic, participants and constraints for the debate.

        Returns:
            The generated session id (UUID4 string).
        """
        session_id = str(uuid.uuid4())

        session = DebateSession(
            session_id=session_id,
            topic=debate_request.topic,
            participants=debate_request.participants,
            constraints=debate_request.constraints,
            rounds=[],
            status="active",
            created_at=datetime.now()
        )

        await self.session_manager.save_session(self.db, session)
        return session_id

    def _get_search_service(self) -> Optional[SearchService]:
        """
        Return a SearchService if a Tavily API key is configured, else None.
        """
        tavily_key = ApiKeyService.get_api_key(self.db, "tavily")
        if tavily_key:
            return SearchService(api_key=tavily_key)
        return None

    async def run_debate(self, session_id: str) -> DebateSession:
        """
        Run the complete debate for an existing session.

        Alternates speakers round-robin for up to ``constraints.max_rounds``
        rounds, optionally injecting web-search evidence, then summarizes the
        debate and marks the session completed.

        Args:
            session_id: Id of a session previously created via create_session.

        Returns:
            The completed (or terminated) DebateSession.

        Raises:
            ValueError: If no session exists for ``session_id``.
        """
        session = await self.session_manager.get_session(self.db, session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")

        # Initialize providers for each participant.
        providers = {}
        for participant in session.participants:
            provider = self.provider_factory.create_provider(
                self.db,
                participant.provider,
                participant.api_key  # May be None; provider then fetches from DB
            )
            providers[participant.model_identifier] = provider

        # Initialize search service if web search is enabled; fall back to a
        # search-less debate when no Tavily key is configured.
        search_service = None
        web_search_enabled = session.constraints.web_search_enabled
        web_search_mode = session.constraints.web_search_mode
        if web_search_enabled:
            search_service = self._get_search_service()
            if not search_service:
                print("Warning: Web search enabled but no Tavily API key found. Disabling search.")
                web_search_enabled = False

        # Run the debate rounds.
        for round_num in range(session.constraints.max_rounds):
            # Stop early if the session was terminated externally.
            if session.status != "active":
                break

            # Alternate between participants (round-robin).
            current_participant = session.participants[round_num % len(session.participants)]
            provider = providers[current_participant.model_identifier]

            # Perform an automatic search before the turn if enabled.
            search_evidence = None
            if web_search_enabled and web_search_mode in ("auto", "both"):
                search_evidence = self._perform_automatic_search(
                    search_service, session, round_num
                )

            # Prepare context for the current turn (with search results if available).
            context = self._prepare_context(session, current_participant.stance, search_evidence)

            # Tool calling is only used when the mode allows it AND the
            # provider actually supports tools.
            use_tool_calling = (
                web_search_enabled
                and web_search_mode in ("tool", "both")
                and provider.supports_tools()
            )

            if use_tool_calling:
                response, tool_evidence = await self._handle_tool_calls(
                    provider, current_participant.model_identifier,
                    context, search_service
                )
                # Merge tool-based evidence with auto evidence.
                if tool_evidence:
                    if search_evidence:
                        search_evidence.results.extend(tool_evidence.results)
                        search_evidence.query += f" | {tool_evidence.query}"
                        search_evidence.mode = "both"
                    else:
                        search_evidence = tool_evidence
            else:
                response = await provider.generate_response(
                    model=current_participant.model_identifier,
                    prompt=context
                )

            # Clean the response to remove any echoed prompt/meta text.
            # (_clean_response also tolerates a None/empty provider response.)
            response = self._clean_response(response)

            # Record the round.
            round_data = DebateRound(
                round_number=round_num + 1,
                speaker=current_participant.model_identifier,
                stance=current_participant.stance,
                content=response,
                timestamp=datetime.now(),
                token_count=len(response.split()),  # Approximate token count
                search_evidence=search_evidence
            )

            session.rounds.append(round_data)

            # Update evidence library with search results from this round.
            if round_data.search_evidence:
                self._update_evidence_library(session, round_data)

            # Persist progress after every round so status queries see it.
            await self.session_manager.update_session(self.db, session)

            # Small delay between rounds to simulate realistic interaction.
            await asyncio.sleep(1)

        # Generate summary after all rounds are complete.
        summary = await summarize_debate(session)
        session.summary = summary
        session.status = "completed"
        session.completed_at = datetime.now()

        await self.session_manager.update_session(self.db, session)
        return session

    def _perform_automatic_search(
        self, search_service: SearchService, session: DebateSession, round_num: int
    ) -> Optional[SearchEvidence]:
        """
        Perform an automatic web search based on the topic and the last
        opponent argument (if any).

        Returns:
            SearchEvidence in "auto" mode, or None when the search yields
            no results.
        """
        last_opponent_arg = None
        if session.rounds:
            last_opponent_arg = session.rounds[-1].content

        query = SearchService.generate_search_query(session.topic, last_opponent_arg)
        results = search_service.search(query, max_results=3)

        if results:
            return SearchEvidence(
                query=query,
                results=results,
                mode="auto"
            )
        return None

    async def _handle_tool_calls(
        self, provider, model: str, context: str, search_service: SearchService
    ) -> tuple:
        """
        Handle the tool-calling flow: send the prompt with tool definitions,
        execute any requested web searches, then re-prompt with the results.
        At most 2 tool-call iterations are performed.

        Returns:
            (final_response_text, SearchEvidence_or_None)
        """
        tools = [SearchService.get_tool_definition()]
        all_search_results = []
        all_queries = []

        text, tool_calls = await provider.generate_response_with_tools(
            model=model,
            prompt=context,
            tools=tools,
            max_tokens=500
        )

        # If no tool calls, return the text response directly.
        if not tool_calls:
            return text, None

        # Process up to 2 rounds of tool calls.
        for iteration in range(2):
            if not tool_calls:
                break

            # Execute each tool call.
            tool_results_text = []
            for tc in tool_calls:
                if tc["name"] == "web_search":
                    query = tc["arguments"].get("query", "")
                    all_queries.append(query)
                    results = search_service.search(query, max_results=3)
                    all_search_results.extend(results)

                    # Format this call's results for the model.
                    call_evidence = SearchEvidence(query=query, results=results, mode="tool")
                    tool_results_text.append(SearchService.format_results_for_context(call_evidence))

            # Re-prompt the model with tool results appended to the
            # original context.
            augmented_context = context + "\n" + "\n".join(tool_results_text)
            augmented_context += "\n请基于以上搜索结果和辩论历史,给出你的论点。"

            text, tool_calls = await provider.generate_response_with_tools(
                model=model,
                prompt=augmented_context,
                tools=tools,
                max_tokens=500
            )

        # Build combined evidence across all tool calls.
        evidence = None
        if all_search_results:
            evidence = SearchEvidence(
                query=" | ".join(all_queries),
                results=all_search_results,
                mode="tool"
            )

        # If we still got no text (model keeps calling tools), fall back to a
        # plain completion without tools.
        if not text:
            text = await provider.generate_response(model=model, prompt=context, max_tokens=500)

        return text, evidence

    def _update_evidence_library(self, session: DebateSession, round_data: DebateRound):
        """
        Merge search results from a round into the session's evidence
        library, deduplicating by URL and tracking which round/speaker
        referenced each entry.
        """
        ref = EvidenceReference(
            round_number=round_data.round_number,
            speaker=round_data.speaker,
            stance=round_data.stance
        )

        # URL -> index into the library, for O(1) dedup lookups.
        url_index = {entry.url: i for i, entry in enumerate(session.evidence_library)}

        for result in round_data.search_evidence.results:
            if result.url in url_index:
                entry = session.evidence_library[url_index[result.url]]
                # Avoid duplicate references (same round + speaker).
                if not any(
                    r.round_number == ref.round_number and r.speaker == ref.speaker
                    for r in entry.references
                ):
                    entry.references.append(ref)
            else:
                new_entry = EvidenceEntry(
                    title=result.title,
                    url=result.url,
                    snippet=result.snippet,
                    score=result.score,
                    references=[ref]
                )
                session.evidence_library.append(new_entry)
                url_index[result.url] = len(session.evidence_library) - 1

    def _prepare_context(
        self, session: DebateSession, current_stance: DebateStance,
        search_evidence: Optional[SearchEvidence] = None
    ) -> str:
        """
        Build the prompt for the current model turn: topic, stance, rules,
        full round history and (optionally) formatted search evidence.
        """
        # Determine the stance of the current speaker.
        if current_stance == DebateStance.PRO:
            position_desc = "正方(支持方)"
            opposing_desc = "反方(反对方)"
        else:
            position_desc = "反方(反对方)"
            opposing_desc = "正方(支持方)"

        # Build the context with previous rounds.
        context_parts = [
            f"辩论主题: {session.topic}",
            f"你的立场: {position_desc}",
            "辩论规则:",
            "- 必须回应对方上一轮的核心论点",
            "- 不得重复自己已提出的观点",
            "- 输出长度限制在合理范围内",
            "\n历史辩论记录:"
        ]

        for round_data in session.rounds:
            stance_text = "正方" if round_data.stance == DebateStance.PRO else "反方"
            context_parts.append(f"第{round_data.round_number}轮 - {stance_text}: {round_data.content}")

        # Inject search results if available.
        if search_evidence:
            context_parts.append(SearchService.format_results_for_context(search_evidence))

        context_parts.append(f"\n现在轮到你 ({position_desc}) 发言,请基于以上内容进行回应。注意:直接给出你的论点内容,不要重复上述提示词、辩论规则或历史记录。")

        return "\n".join(context_parts)

    def _clean_response(self, response: str) -> str:
        """
        Strip echoed prompt/meta prefixes (e.g. "正方:", "第2轮 - 反方:")
        from a model response.

        Returns an empty string for a None/empty response instead of raising,
        so a misbehaving provider cannot crash the round loop.
        """
        if not response:
            return ""

        cleaned = response.strip()
        for pattern in self._CLEAN_PATTERNS:
            cleaned = pattern.sub('', cleaned)

        return cleaned.strip()

    async def get_session_status(self, session_id: str) -> Optional[DebateSession]:
        """
        Return the current state of a debate session, or None if unknown.
        """
        return await self.session_manager.get_session(self.db, session_id)

    async def terminate_session(self, session_id: str):
        """
        Terminate a debate session prematurely.

        Sets status to "terminated", which run_debate's round loop checks
        before each round; unknown session ids are silently ignored.
        """
        session = await self.session_manager.get_session(self.db, session_id)
        if session:
            session.status = "terminated"
            await self.session_manager.update_session(self.db, session)
|