49 lines
1.5 KiB
Python
49 lines
1.5 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Optional
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
class LLMProvider(ABC):
    """
    Common interface that every LLM provider implementation must satisfy.

    Both ``__init__`` and ``generate_response`` are abstract, so a concrete
    subclass has to override each of them before it can be instantiated.
    """

    @abstractmethod
    def __init__(self, db: Session, api_key: Optional[str] = None):
        """
        Set up the provider.

        Args:
            db: Database session handed to the provider.
            api_key: Optional API key; defaults to ``None``.
        """
        ...

    @abstractmethod
    async def generate_response(self, model: str, prompt: str, max_tokens: Optional[int] = None) -> str:
        """
        Produce a response from the LLM (coroutine; must be awaited).

        Args:
            model: Name of the model to use.
            prompt: Prompt text to send.
            max_tokens: Optional token limit; defaults to ``None``.
                NOTE(review): how ``None`` is interpreted is left to each
                subclass -- confirm against the concrete providers.

        Returns:
            The generated response as a string.
        """
        ...
|
|
|
|
|
|
class ProviderFactory:
    """
    Factory class to create provider instances
    """

    @staticmethod
    def create_provider(db: Session, provider_type: str, api_key: Optional[str] = None):
        """
        Create a provider instance based on the type.

        Args:
            db: Database session forwarded to the provider constructor.
            provider_type: Which provider to build. Either the raw string
                ("openai", "claude", "qwen", "deepseek") or an enum-like
                object whose ``.value`` is one of those strings.
            api_key: Optional API key forwarded to the provider constructor.

        Returns:
            A new instance of the provider class matching ``provider_type``.

        Raises:
            ValueError: If ``provider_type`` does not name a supported provider.
        """
        # The parameter is annotated as ``str`` but the lookup used to read
        # ``.value`` unconditionally, so a plain string crashed with
        # AttributeError. Fall back to the argument itself when it has no
        # ``.value`` so both enum members and raw strings work.
        provider_key = getattr(provider_type, "value", provider_type)

        # Each provider module is imported only inside its own branch, so
        # selecting one provider never imports the others.
        if provider_key == "openai":
            from providers.openai_provider import OpenAIProvider

            return OpenAIProvider(db, api_key)

        if provider_key == "claude":
            from providers.claude_provider import ClaudeProvider

            return ClaudeProvider(db, api_key)

        if provider_key == "qwen":
            from providers.qwen_provider import QwenProvider

            return QwenProvider(db, api_key)

        if provider_key == "deepseek":
            from providers.deepseek_provider import DeepSeekProvider

            return DeepSeekProvider(db, api_key)

        raise ValueError(f"Unsupported provider type: {provider_type}")