76 lines
2.7 KiB
Python
76 lines
2.7 KiB
Python
import json
|
|
from typing import Optional, List, Dict, Any, Tuple
|
|
from sqlalchemy.orm import Session
|
|
from openai import AsyncOpenAI
|
|
|
|
from providers.base_provider import LLMProvider
|
|
from services.api_key_service import ApiKeyService
|
|
|
|
# System message prepended to every request this module sends.
# English gist: "You are taking part in a structured debate. Debate according to
# the rules in the user message, state your arguments directly, and do not repeat
# the prompt or the conversation history."
SYSTEM_PROMPT = (
    "你正在参与一场结构化辩论。请按照用户消息中的规则进行辩论,直接给出你的论点,不要重复提示词或历史记录。"
)
|
|
|
|
|
|
class QwenProvider(LLMProvider):
    """Qwen API provider implementation using DashScope OpenAI-compatible API.

    Requests go through an ``AsyncOpenAI`` client pointed at ``BASE_URL``.
    Every request sends ``SYSTEM_PROMPT`` as the system message, followed by
    the caller's prompt as a single user message.
    """

    # OpenAI-compatible endpoint the client talks to.
    BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    # Completion token limit used when the caller passes no (or a falsy) max_tokens.
    DEFAULT_MAX_TOKENS = 500

    def __init__(self, db: Session, api_key: Optional[str] = None):
        """Create the provider and its API client.

        Args:
            db: Database session used to look up the stored "qwen" API key
                when ``api_key`` is not supplied.
            api_key: Explicit API key; takes precedence over the stored one.

        Raises:
            ValueError: If no key was supplied and none was found via
                ``ApiKeyService``.
        """
        if not api_key:
            api_key = ApiKeyService.get_api_key(db, "qwen")
        if not api_key:
            raise ValueError("Qwen API key not found in database or provided")

        self.client = AsyncOpenAI(api_key=api_key, base_url=self.BASE_URL)

    def supports_tools(self) -> bool:
        """Return True: this provider accepts tool (function-calling) definitions."""
        return True

    @staticmethod
    def _build_messages(prompt: str) -> List[Dict[str, str]]:
        """Build the chat payload: the fixed system prompt plus the user prompt."""
        return [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]

    @staticmethod
    def _first_message(response: Any) -> Any:
        """Return the first choice's message, failing with a clear error if there is none."""
        if not response.choices:
            raise ValueError("Qwen API returned no choices")
        return response.choices[0].message

    @staticmethod
    def _parse_arguments(raw: Optional[str]) -> Any:
        """Decode a tool call's JSON argument string.

        An empty or missing string means "no arguments" and yields ``{}``.
        Without this check, ``json.loads("")`` would raise. Malformed JSON still
        raises ``json.JSONDecodeError``.
        """
        if not raw:
            return {}
        return json.loads(raw)

    async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str:
        """Send ``prompt`` to ``model`` and return the reply text.

        Args:
            model: Model name passed through to the API.
            prompt: Content of the user message.
            max_tokens: Completion token limit; ``DEFAULT_MAX_TOKENS`` when falsy.

        Returns:
            The assistant's text with surrounding whitespace stripped, or ``""``
            if the reply carried no text content.

        Raises:
            Exception: Wrapping (and chained to) any error raised by the API call
                or by response handling.
        """
        try:
            response = await self.client.chat.completions.create(
                model=model,
                messages=self._build_messages(prompt),
                max_tokens=max_tokens or self.DEFAULT_MAX_TOKENS,
            )
            # content can be None; mirror the tools method instead of crashing on None.strip().
            return (self._first_message(response).content or "").strip()
        except Exception as e:
            raise Exception(f"Error calling Qwen API: {str(e)}") from e

    async def generate_response_with_tools(
        self,
        model: str,
        prompt: str,
        tools: List[Dict[str, Any]],
        max_tokens: Optional[int] = None
    ) -> Tuple[str, List[Dict[str, Any]]]:
        """Send ``prompt`` with ``tools`` available and return the text plus any tool calls.

        The model chooses freely whether to call a tool (``tool_choice="auto"``).

        Args:
            model: Model name passed through to the API.
            prompt: Content of the user message.
            tools: Tool definitions, forwarded unchanged to the API.
            max_tokens: Completion token limit; ``DEFAULT_MAX_TOKENS`` when falsy.

        Returns:
            A ``(text, tool_calls)`` tuple. ``text`` is the stripped reply text,
            or ``""`` if the reply has none. ``tool_calls`` is a list of
            ``{"name": ..., "arguments": ...}`` dicts, where ``arguments`` is
            the JSON-decoded argument string (``{}`` when empty).

        Raises:
            Exception: Wrapping (and chained to) any error raised by the API call
                or by response handling, including malformed tool-call JSON.
        """
        try:
            response = await self.client.chat.completions.create(
                model=model,
                messages=self._build_messages(prompt),
                tools=tools,
                tool_choice="auto",
                max_tokens=max_tokens or self.DEFAULT_MAX_TOKENS,
            )
            message = self._first_message(response)
            tool_calls = [
                {
                    "name": tc.function.name,
                    "arguments": self._parse_arguments(tc.function.arguments),
                }
                for tc in message.tool_calls or []
            ]
            return (message.content or "").strip(), tool_calls
        except Exception as e:
            raise Exception(f"Error calling Qwen API with tools: {str(e)}") from e
|