init
This commit is contained in:
0
services/__init__.py
Normal file
0
services/__init__.py
Normal file
71
services/api_key_service.py
Normal file
71
services/api_key_service.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from sqlalchemy.orm import Session
|
||||
from db_models import ApiKey
|
||||
import os
|
||||
|
||||
# Initialize the encryption key from environment or generate a new one
|
||||
ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", Fernet.generate_key().decode())
|
||||
cipher_suite = Fernet(ENCRYPTION_KEY.encode())
|
||||
|
||||
|
||||
class ApiKeyService:
|
||||
"""
|
||||
Service for managing API keys in the database
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def encrypt_api_key(api_key: str) -> str:
|
||||
"""
|
||||
Encrypt an API key
|
||||
"""
|
||||
encrypted_key = cipher_suite.encrypt(api_key.encode())
|
||||
return encrypted_key.decode()
|
||||
|
||||
@staticmethod
|
||||
def decrypt_api_key(encrypted_api_key: str) -> str:
|
||||
"""
|
||||
Decrypt an API key
|
||||
"""
|
||||
decrypted_key = cipher_suite.decrypt(encrypted_api_key.encode())
|
||||
return decrypted_key.decode()
|
||||
|
||||
@staticmethod
|
||||
def get_api_key(db: Session, provider: str) -> str:
|
||||
"""
|
||||
Retrieve and decrypt an API key for a provider
|
||||
"""
|
||||
api_key_record = db.query(ApiKey).filter(ApiKey.provider == provider).first()
|
||||
if not api_key_record or not api_key_record.api_key_encrypted:
|
||||
return None
|
||||
try:
|
||||
return ApiKeyService.decrypt_api_key(api_key_record.api_key_encrypted)
|
||||
except InvalidToken:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def set_api_key(db: Session, provider: str, api_key: str) -> bool:
|
||||
"""
|
||||
Encrypt and store an API key for a provider
|
||||
"""
|
||||
encrypted_key = ApiKeyService.encrypt_api_key(api_key)
|
||||
|
||||
# Check if record exists
|
||||
api_key_record = db.query(ApiKey).filter(ApiKey.provider == provider).first()
|
||||
|
||||
if api_key_record:
|
||||
# Update existing record
|
||||
api_key_record.api_key_encrypted = encrypted_key
|
||||
else:
|
||||
# Create new record
|
||||
api_key_record = ApiKey(
|
||||
provider=provider,
|
||||
api_key_encrypted=encrypted_key
|
||||
)
|
||||
db.add(api_key_record)
|
||||
|
||||
try:
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
db.rollback()
|
||||
return False
|
||||
100
services/config_service.py
Normal file
100
services/config_service.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import yaml
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
|
||||
CONFIG_PATH = Path(os.getenv("CONFIG_PATH", "/app/config/dialectica.yaml"))
|
||||
|
||||
# Reuse the same encryption key used for API keys
|
||||
_ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", "")
|
||||
_cipher = Fernet(_ENCRYPTION_KEY.encode()) if _ENCRYPTION_KEY else None
|
||||
|
||||
# Fields that should be encrypted in the YAML file
|
||||
_SECRET_FIELDS = {"password"}
|
||||
|
||||
|
||||
def _encrypt(value: str) -> str:
|
||||
if not _cipher or not value:
|
||||
return value
|
||||
return "ENC:" + _cipher.encrypt(value.encode()).decode()
|
||||
|
||||
|
||||
def _decrypt(value: str) -> str:
|
||||
if not _cipher or not isinstance(value, str) or not value.startswith("ENC:"):
|
||||
return value
|
||||
try:
|
||||
return _cipher.decrypt(value[4:].encode()).decode()
|
||||
except InvalidToken:
|
||||
return value
|
||||
|
||||
|
||||
def _encrypt_secrets(data: dict) -> dict:
|
||||
"""Deep-copy dict, encrypting secret fields."""
|
||||
out = {}
|
||||
for k, v in data.items():
|
||||
if isinstance(v, dict):
|
||||
out[k] = _encrypt_secrets(v)
|
||||
elif k in _SECRET_FIELDS and isinstance(v, str) and not v.startswith("ENC:"):
|
||||
out[k] = _encrypt(v)
|
||||
else:
|
||||
out[k] = v
|
||||
return out
|
||||
|
||||
|
||||
def _decrypt_secrets(data: dict) -> dict:
|
||||
"""Deep-copy dict, decrypting secret fields."""
|
||||
out = {}
|
||||
for k, v in data.items():
|
||||
if isinstance(v, dict):
|
||||
out[k] = _decrypt_secrets(v)
|
||||
elif k in _SECRET_FIELDS and isinstance(v, str):
|
||||
out[k] = _decrypt(v)
|
||||
else:
|
||||
out[k] = v
|
||||
return out
|
||||
|
||||
|
||||
class ConfigService:
|
||||
"""Read / write config/dialectica.yaml."""
|
||||
|
||||
@staticmethod
|
||||
def load() -> dict:
|
||||
"""Load config, returning decrypted values. Empty dict if file missing."""
|
||||
if not CONFIG_PATH.exists():
|
||||
return {}
|
||||
with open(CONFIG_PATH) as f:
|
||||
raw = yaml.safe_load(f) or {}
|
||||
return _decrypt_secrets(raw)
|
||||
|
||||
@staticmethod
|
||||
def save(config: dict):
|
||||
"""Save config, encrypting secret fields."""
|
||||
CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
encrypted = _encrypt_secrets(config)
|
||||
with open(CONFIG_PATH, "w") as f:
|
||||
yaml.dump(encrypted, f, default_flow_style=False, allow_unicode=True)
|
||||
|
||||
@staticmethod
|
||||
def is_db_configured() -> bool:
|
||||
config = ConfigService.load()
|
||||
db = config.get("database", {})
|
||||
return bool(db.get("host") and db.get("database"))
|
||||
|
||||
@staticmethod
|
||||
def get_database_url() -> str | None:
|
||||
config = ConfigService.load()
|
||||
db = config.get("database", {})
|
||||
if not (db.get("host") and db.get("database")):
|
||||
return None
|
||||
user = db.get("user", "root")
|
||||
password = db.get("password", "")
|
||||
host = db["host"]
|
||||
port = db.get("port", 3306)
|
||||
database = db["database"]
|
||||
return f"mysql+pymysql://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{database}"
|
||||
|
||||
@staticmethod
|
||||
def is_initialized() -> bool:
|
||||
return ConfigService.load().get("initialized", False)
|
||||
104
services/search_service.py
Normal file
104
services/search_service.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from typing import List, Optional
|
||||
from models.debate import SearchResult, SearchEvidence
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""
|
||||
Tavily web search wrapper for debate research.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
from tavily import TavilyClient
|
||||
self.client = TavilyClient(api_key=api_key)
|
||||
|
||||
def search(self, query: str, max_results: int = 5) -> List[SearchResult]:
|
||||
"""
|
||||
Perform a web search using Tavily and return structured results.
|
||||
"""
|
||||
try:
|
||||
response = self.client.search(
|
||||
query=query,
|
||||
max_results=max_results,
|
||||
search_depth="basic"
|
||||
)
|
||||
results = []
|
||||
for item in response.get("results", []):
|
||||
results.append(SearchResult(
|
||||
title=item.get("title", ""),
|
||||
url=item.get("url", ""),
|
||||
snippet=item.get("content", "")[:500], # Truncate to avoid token bloat
|
||||
score=item.get("score")
|
||||
))
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"Tavily search error: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def format_results_for_context(evidence: SearchEvidence) -> str:
|
||||
"""
|
||||
Format search results into a string suitable for injecting into the LLM context.
|
||||
"""
|
||||
if not evidence or not evidence.results:
|
||||
return ""
|
||||
lines = [f"\n[网络搜索结果] 搜索词: \"{evidence.query}\""]
|
||||
for i, r in enumerate(evidence.results, 1):
|
||||
lines.append(f" {i}. {r.title}")
|
||||
lines.append(f" {r.snippet}")
|
||||
lines.append(f" 来源: {r.url}")
|
||||
lines.append("[搜索结果结束]\n")
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def generate_search_query(topic: str, last_opponent_argument: Optional[str] = None) -> str:
|
||||
"""
|
||||
Generate a search query from the debate topic and the opponent's last argument.
|
||||
"""
|
||||
if last_opponent_argument:
|
||||
# Extract key phrases from the opponent's argument (first ~100 chars)
|
||||
snippet = last_opponent_argument[:100].strip()
|
||||
return f"{topic} {snippet}"
|
||||
return topic
|
||||
|
||||
@staticmethod
|
||||
def get_tool_definition() -> dict:
|
||||
"""
|
||||
Return the web_search tool definition for function calling.
|
||||
"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web for current information relevant to the debate topic. Use this to find facts, statistics, or recent news that support your argument.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query to look up"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_tool_definition_anthropic() -> dict:
|
||||
"""
|
||||
Return the web_search tool definition in Anthropic format.
|
||||
"""
|
||||
return {
|
||||
"name": "web_search",
|
||||
"description": "Search the web for current information relevant to the debate topic. Use this to find facts, statistics, or recent news that support your argument.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query to look up"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user