init
This commit is contained in:
358
orchestrator/debate_orchestrator.py
Normal file
358
orchestrator/debate_orchestrator.py
Normal file
@@ -0,0 +1,358 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.debate import (
|
||||
DebateRequest, DebateSession, DebateRound,
|
||||
DebateStance, DebateParticipant, SearchEvidence, SearchResult,
|
||||
EvidenceEntry, EvidenceReference
|
||||
)
|
||||
from providers.provider_factory import ProviderFactory
|
||||
from storage.session_manager import SessionManager
|
||||
from utils.summarizer import summarize_debate
|
||||
from services.search_service import SearchService
|
||||
from services.api_key_service import ApiKeyService
|
||||
|
||||
|
||||
class DebateOrchestrator:
|
||||
"""
|
||||
Orchestrates the debate between multiple language models
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.session_manager = SessionManager()
|
||||
self.provider_factory = ProviderFactory()
|
||||
|
||||
async def create_session(self, debate_request: DebateRequest) -> str:
|
||||
"""
|
||||
Create a new debate session
|
||||
"""
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
session = DebateSession(
|
||||
session_id=session_id,
|
||||
topic=debate_request.topic,
|
||||
participants=debate_request.participants,
|
||||
constraints=debate_request.constraints,
|
||||
rounds=[],
|
||||
status="active",
|
||||
created_at=datetime.now()
|
||||
)
|
||||
|
||||
await self.session_manager.save_session(self.db, session)
|
||||
return session_id
|
||||
|
||||
def _get_search_service(self) -> Optional[SearchService]:
|
||||
"""
|
||||
Get a SearchService instance if Tavily API key is available.
|
||||
"""
|
||||
tavily_key = ApiKeyService.get_api_key(self.db, "tavily")
|
||||
if tavily_key:
|
||||
return SearchService(api_key=tavily_key)
|
||||
return None
|
||||
|
||||
async def run_debate(self, session_id: str) -> DebateSession:
|
||||
"""
|
||||
Run the complete debate process
|
||||
"""
|
||||
session = await self.session_manager.get_session(self.db, session_id)
|
||||
if not session:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
# Initialize providers for each participant
|
||||
providers = {}
|
||||
for participant in session.participants:
|
||||
provider = self.provider_factory.create_provider(
|
||||
self.db,
|
||||
participant.provider,
|
||||
participant.api_key # This can be None, and the provider will fetch from DB
|
||||
)
|
||||
providers[participant.model_identifier] = provider
|
||||
|
||||
# Initialize search service if web search is enabled
|
||||
search_service = None
|
||||
web_search_enabled = session.constraints.web_search_enabled
|
||||
web_search_mode = session.constraints.web_search_mode
|
||||
if web_search_enabled:
|
||||
search_service = self._get_search_service()
|
||||
if not search_service:
|
||||
print("Warning: Web search enabled but no Tavily API key found. Disabling search.")
|
||||
web_search_enabled = False
|
||||
|
||||
# Run the debate rounds
|
||||
for round_num in range(session.constraints.max_rounds):
|
||||
if session.status != "active":
|
||||
break
|
||||
|
||||
# Alternate between participants
|
||||
current_participant = session.participants[round_num % len(session.participants)]
|
||||
provider = providers[current_participant.model_identifier]
|
||||
|
||||
# Perform automatic search if enabled
|
||||
search_evidence = None
|
||||
if web_search_enabled and web_search_mode in ("auto", "both"):
|
||||
search_evidence = self._perform_automatic_search(
|
||||
search_service, session, round_num
|
||||
)
|
||||
|
||||
# Prepare context for the current turn (with search results if available)
|
||||
context = self._prepare_context(session, current_participant.stance, search_evidence)
|
||||
|
||||
# Determine if we should use tool calling for this round
|
||||
use_tool_calling = (
|
||||
web_search_enabled
|
||||
and web_search_mode in ("tool", "both")
|
||||
and provider.supports_tools()
|
||||
)
|
||||
|
||||
if use_tool_calling:
|
||||
response, tool_evidence = await self._handle_tool_calls(
|
||||
provider, current_participant.model_identifier,
|
||||
context, search_service
|
||||
)
|
||||
# Merge tool-based evidence with auto evidence
|
||||
if tool_evidence:
|
||||
if search_evidence:
|
||||
search_evidence.results.extend(tool_evidence.results)
|
||||
search_evidence.query += f" | {tool_evidence.query}"
|
||||
search_evidence.mode = "both"
|
||||
else:
|
||||
search_evidence = tool_evidence
|
||||
else:
|
||||
response = await provider.generate_response(
|
||||
model=current_participant.model_identifier,
|
||||
prompt=context
|
||||
)
|
||||
|
||||
# Clean the response to remove any echoed prompt/meta text
|
||||
response = self._clean_response(response)
|
||||
|
||||
# Create a new round
|
||||
round_data = DebateRound(
|
||||
round_number=round_num + 1,
|
||||
speaker=current_participant.model_identifier,
|
||||
stance=current_participant.stance,
|
||||
content=response,
|
||||
timestamp=datetime.now(),
|
||||
token_count=len(response.split()), # Approximate token count
|
||||
search_evidence=search_evidence
|
||||
)
|
||||
|
||||
session.rounds.append(round_data)
|
||||
|
||||
# Update evidence library with search results from this round
|
||||
if round_data.search_evidence:
|
||||
self._update_evidence_library(session, round_data)
|
||||
|
||||
# Update session in storage
|
||||
await self.session_manager.update_session(self.db, session)
|
||||
|
||||
# Small delay between rounds to simulate realistic interaction
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Generate summary after all rounds are complete
|
||||
summary = await summarize_debate(session)
|
||||
session.summary = summary
|
||||
session.status = "completed"
|
||||
session.completed_at = datetime.now()
|
||||
|
||||
await self.session_manager.update_session(self.db, session)
|
||||
return session
|
||||
|
||||
def _perform_automatic_search(
|
||||
self, search_service: SearchService, session: DebateSession, round_num: int
|
||||
) -> Optional[SearchEvidence]:
|
||||
"""
|
||||
Perform an automatic web search based on the topic and last opponent argument.
|
||||
"""
|
||||
last_opponent_arg = None
|
||||
if session.rounds:
|
||||
last_opponent_arg = session.rounds[-1].content
|
||||
|
||||
query = SearchService.generate_search_query(session.topic, last_opponent_arg)
|
||||
results = search_service.search(query, max_results=3)
|
||||
|
||||
if results:
|
||||
return SearchEvidence(
|
||||
query=query,
|
||||
results=results,
|
||||
mode="auto"
|
||||
)
|
||||
return None
|
||||
|
||||
async def _handle_tool_calls(
|
||||
self, provider, model: str, context: str, search_service: SearchService
|
||||
) -> tuple:
|
||||
"""
|
||||
Handle tool calling flow: send prompt with tools, execute any tool calls,
|
||||
then re-prompt with results. Max 2 tool call iterations.
|
||||
Returns (final_response_text, SearchEvidence_or_None).
|
||||
"""
|
||||
tools = [SearchService.get_tool_definition()]
|
||||
all_search_results = []
|
||||
all_queries = []
|
||||
|
||||
text, tool_calls = await provider.generate_response_with_tools(
|
||||
model=model,
|
||||
prompt=context,
|
||||
tools=tools,
|
||||
max_tokens=500
|
||||
)
|
||||
|
||||
# If no tool calls, return the text response directly
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Process up to 2 rounds of tool calls
|
||||
for iteration in range(2):
|
||||
if not tool_calls:
|
||||
break
|
||||
|
||||
# Execute each tool call
|
||||
tool_results_text = []
|
||||
for tc in tool_calls:
|
||||
if tc["name"] == "web_search":
|
||||
query = tc["arguments"].get("query", "")
|
||||
all_queries.append(query)
|
||||
results = search_service.search(query, max_results=3)
|
||||
all_search_results.extend(results)
|
||||
|
||||
# Format results for the model
|
||||
evidence = SearchEvidence(query=query, results=results, mode="tool")
|
||||
tool_results_text.append(SearchService.format_results_for_context(evidence))
|
||||
|
||||
# Re-prompt the model with tool results
|
||||
augmented_context = context + "\n" + "\n".join(tool_results_text)
|
||||
augmented_context += "\n请基于以上搜索结果和辩论历史,给出你的论点。"
|
||||
|
||||
text, tool_calls = await provider.generate_response_with_tools(
|
||||
model=model,
|
||||
prompt=augmented_context,
|
||||
tools=tools,
|
||||
max_tokens=500
|
||||
)
|
||||
|
||||
# Build combined evidence
|
||||
evidence = None
|
||||
if all_search_results:
|
||||
evidence = SearchEvidence(
|
||||
query=" | ".join(all_queries),
|
||||
results=all_search_results,
|
||||
mode="tool"
|
||||
)
|
||||
|
||||
# If we still got no text (model keeps calling tools), fall back
|
||||
if not text:
|
||||
text = await provider.generate_response(model=model, prompt=context, max_tokens=500)
|
||||
|
||||
return text, evidence
|
||||
|
||||
def _update_evidence_library(self, session: DebateSession, round_data: DebateRound):
|
||||
"""
|
||||
Merge search results from a round into the session's evidence library, deduplicating by URL.
|
||||
"""
|
||||
ref = EvidenceReference(
|
||||
round_number=round_data.round_number,
|
||||
speaker=round_data.speaker,
|
||||
stance=round_data.stance
|
||||
)
|
||||
|
||||
url_index = {entry.url: i for i, entry in enumerate(session.evidence_library)}
|
||||
|
||||
for result in round_data.search_evidence.results:
|
||||
if result.url in url_index:
|
||||
entry = session.evidence_library[url_index[result.url]]
|
||||
# Avoid duplicate references (same round + speaker)
|
||||
if not any(
|
||||
r.round_number == ref.round_number and r.speaker == ref.speaker
|
||||
for r in entry.references
|
||||
):
|
||||
entry.references.append(ref)
|
||||
else:
|
||||
new_entry = EvidenceEntry(
|
||||
title=result.title,
|
||||
url=result.url,
|
||||
snippet=result.snippet,
|
||||
score=result.score,
|
||||
references=[ref]
|
||||
)
|
||||
session.evidence_library.append(new_entry)
|
||||
url_index[result.url] = len(session.evidence_library) - 1
|
||||
|
||||
def _prepare_context(
|
||||
self, session: DebateSession, current_stance: DebateStance,
|
||||
search_evidence: Optional[SearchEvidence] = None
|
||||
) -> str:
|
||||
"""
|
||||
Prepare the context/prompt for the current model turn
|
||||
"""
|
||||
# Determine the stance of the current speaker
|
||||
if current_stance == DebateStance.PRO:
|
||||
position_desc = "正方(支持方)"
|
||||
opposing_desc = "反方(反对方)"
|
||||
else:
|
||||
position_desc = "反方(反对方)"
|
||||
opposing_desc = "正方(支持方)"
|
||||
|
||||
# Build the context with previous rounds
|
||||
context_parts = [
|
||||
f"辩论主题: {session.topic}",
|
||||
f"你的立场: {position_desc}",
|
||||
"辩论规则:",
|
||||
"- 必须回应对方上一轮的核心论点",
|
||||
"- 不得重复自己已提出的观点",
|
||||
"- 输出长度限制在合理范围内",
|
||||
"\n历史辩论记录:"
|
||||
]
|
||||
|
||||
for round_data in session.rounds:
|
||||
stance_text = "正方" if round_data.stance == DebateStance.PRO else "反方"
|
||||
context_parts.append(f"第{round_data.round_number}轮 - {stance_text}: {round_data.content}")
|
||||
|
||||
# Inject search results if available
|
||||
if search_evidence:
|
||||
context_parts.append(SearchService.format_results_for_context(search_evidence))
|
||||
|
||||
context_parts.append(f"\n现在轮到你 ({position_desc}) 发言,请基于以上内容进行回应。注意:直接给出你的论点内容,不要重复上述提示词、辩论规则或历史记录。")
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
def _clean_response(self, response: str) -> str:
|
||||
"""
|
||||
Clean the model response to remove any echoed prompt/meta text
|
||||
"""
|
||||
import re
|
||||
|
||||
# Remove common prompt echoes and meta prefixes
|
||||
patterns_to_remove = [
|
||||
r'^第\d+轮\s*[-::]\s*(正方|反方)\s*[::]?\s*', # 第X轮 - 正方/反方:
|
||||
r'^(正方|反方)\s*[((][^))]*[))]\s*[::]?\s*', # 正方(支持方):
|
||||
r'^(正方|反方)\s*[::]\s*', # 正方: or 反方:
|
||||
r'^我的立场\s*[::]\s*', # 我的立场:
|
||||
r'^回应\s*[::]\s*', # 回应:
|
||||
r'^辩论发言\s*[::]\s*', # 辩论发言:
|
||||
]
|
||||
|
||||
cleaned = response.strip()
|
||||
for pattern in patterns_to_remove:
|
||||
cleaned = re.sub(pattern, '', cleaned, flags=re.MULTILINE)
|
||||
|
||||
return cleaned.strip()
|
||||
|
||||
async def get_session_status(self, session_id: str) -> Optional[DebateSession]:
|
||||
"""
|
||||
Get the current status of a debate session
|
||||
"""
|
||||
return await self.session_manager.get_session(self.db, session_id)
|
||||
|
||||
async def terminate_session(self, session_id: str):
|
||||
"""
|
||||
Terminate a debate session prematurely
|
||||
"""
|
||||
session = await self.session_manager.get_session(self.db, session_id)
|
||||
if session:
|
||||
session.status = "terminated"
|
||||
await self.session_manager.update_session(self.db, session)
|
||||
Reference in New Issue
Block a user