Files
Dialectic.Backend/providers/base_provider.py
2026-02-12 15:45:48 +00:00

73 lines
2.4 KiB
Python

import importlib
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union

from sqlalchemy.orm import Session
class LLMProvider(ABC):
    """
    Common interface implemented by every LLM backend.

    Concrete providers must implement ``__init__`` and ``generate_response``.
    Tool calling is opt-in: a provider that supports it overrides
    ``supports_tools`` (to return ``True``) and ``generate_response_with_tools``.
    """

    @abstractmethod
    def __init__(self, db: "Session", api_key: Optional[str] = None):
        """
        Set up the provider.

        Args:
            db: Database session handed to the provider.
            api_key: Optional API key; ``None`` when the caller supplies none.
        """
        ...

    @abstractmethod
    async def generate_response(
        self, model: str, prompt: str, max_tokens: Optional[int] = None
    ) -> str:
        """
        Send ``prompt`` to ``model`` and return the generated text.

        Args:
            model: Model identifier understood by the concrete provider.
            prompt: Input text for the model.
            max_tokens: Optional generation limit; how ``None`` is handled is
                up to the concrete provider.
        """
        ...

    def supports_tools(self) -> bool:
        """
        Report whether this provider implements function/tool calling.

        The base implementation always answers ``False``; subclasses that
        support tools override this.
        """
        return False

    async def generate_response_with_tools(
        self,
        model: str,
        prompt: str,
        tools: List[Dict[str, Any]],
        max_tokens: Optional[int] = None,
    ) -> Tuple[str, List[Dict[str, Any]]]:
        """
        Generate a response while offering ``tools`` to the model.

        Returns:
            A ``(text_content, tool_calls)`` pair. Each tool call is a dict with
            keys ``name`` and ``arguments`` (a dict). When no tools are called,
            ``tool_calls`` is empty and ``text_content`` holds the reply.

        The base implementation ignores ``tools`` and delegates to
        ``generate_response``, so it never reports any tool calls.
        """
        # Positional call on purpose: subclasses may name these parameters differently.
        reply = await self.generate_response(model, prompt, max_tokens)
        no_tool_calls: List[Dict[str, Any]] = []
        return reply, no_tool_calls
class ProviderFactory:
    """
    Factory that instantiates the concrete LLM provider for a provider type.

    Provider modules are imported lazily, only when their type is requested,
    so an unused provider's module is never loaded.
    """

    # Provider name -> (module path, class name). Resolved on demand in
    # create_provider rather than imported at module load time.
    # NOTE(review): paths assume the package is importable as `app.providers`;
    # confirm this matches the deployed layout.
    _PROVIDER_CLASSES: Dict[str, Tuple[str, str]] = {
        "openai": ("app.providers.openai_provider", "OpenAIProvider"),
        "claude": ("app.providers.claude_provider", "ClaudeProvider"),
        "qwen": ("app.providers.qwen_provider", "QwenProvider"),
        "deepseek": ("app.providers.deepseek_provider", "DeepSeekProvider"),
    }

    @staticmethod
    def create_provider(
        db: "Session",
        provider_type: Union[str, Enum],
        api_key: Optional[str] = None,
    ) -> "LLMProvider":
        """
        Create a provider instance based on the type.

        Args:
            db: Database session passed through to the provider constructor.
            provider_type: Provider name, either as a plain string (e.g.
                ``"openai"``) or as an object whose ``.value`` is that string
                (e.g. an Enum member).
            api_key: Optional API key passed through to the provider constructor.

        Returns:
            A new instance of the matching provider class, constructed as
            ``ProviderClass(db, api_key)``.

        Raises:
            ValueError: If ``provider_type`` is not a supported provider name.
            ImportError: If the provider's module cannot be imported.
        """
        # Previously this always read `.value`, so a plain str (as the old
        # `str` annotation advertised) crashed with AttributeError.
        name = getattr(provider_type, "value", provider_type)
        try:
            module_path, class_name = ProviderFactory._PROVIDER_CLASSES[name]
        except KeyError:
            raise ValueError(f"Unsupported provider type: {provider_type}") from None
        provider_cls = getattr(importlib.import_module(module_path), class_name)
        return provider_cls(db, api_key)