Files
Dialectic.Backend/services/config_service.py
2026-02-12 15:45:48 +00:00

101 lines
3.0 KiB
Python

import logging
import os
from pathlib import Path
from urllib.parse import quote_plus

import yaml
from cryptography.fernet import Fernet, InvalidToken
# Location of the YAML config file; overridable via the CONFIG_PATH env var.
CONFIG_PATH = Path(os.getenv("CONFIG_PATH", "/app/config/dialectica.yaml"))

# Reuse the same encryption key used for API keys. When the key is unset or
# empty, _cipher stays None and encryption/decryption become no-ops.
_ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", "")
if _ENCRYPTION_KEY:
    _cipher = Fernet(_ENCRYPTION_KEY.encode())
else:
    _cipher = None

# Key names whose string values are stored encrypted in the YAML file.
_SECRET_FIELDS = {"password"}
def _encrypt(value: str) -> str:
    """Encrypt *value* and tag the token with an ``ENC:`` prefix.

    When no cipher is configured, or *value* is empty, *value* is returned
    unchanged (stored as plaintext).
    """
    if _cipher and value:
        token = _cipher.encrypt(value.encode())
        return f"ENC:{token.decode()}"
    return value
def _decrypt(value: str) -> str:
    """Reverse ``_encrypt`` for values carrying the ``ENC:`` prefix.

    Non-strings and strings without the prefix are returned unchanged.

    If a value is tagged ``ENC:`` but cannot be decrypted, the tagged value
    is returned as-is. This happens when no cipher is configured or the token
    was made with a different ENCRYPTION_KEY. A later save does not wrap the
    value a second time (``_encrypt_secrets`` skips ``ENC:`` values).

    A warning is logged in that case. Previously the ciphertext was handed to
    callers silently, as if it were the real secret.
    """
    if not isinstance(value, str) or not value.startswith("ENC:"):
        return value
    logger = logging.getLogger(__name__)
    if not _cipher:
        logger.warning(
            "Config contains an encrypted value but ENCRYPTION_KEY is not set; "
            "leaving it encrypted"
        )
        return value
    try:
        return _cipher.decrypt(value[4:].encode()).decode()
    except InvalidToken:
        # Best-effort: keep the original token so it survives a re-save, but
        # make the key mismatch visible instead of failing later and obscurely.
        logger.warning(
            "Could not decrypt config value: ENCRYPTION_KEY does not match "
            "the key it was encrypted with; leaving it encrypted"
        )
        return value
def _transform_secrets(node, transform):
    """Return a deep copy of *node* with *transform* applied to secret values.

    Dicts and lists are rebuilt recursively, so secret fields nested inside
    lists of mappings are handled as well. A value is passed to *transform*
    only when its key is in ``_SECRET_FIELDS`` and the value is a string.
    All other values are copied through, recursing into any containers.
    """
    if isinstance(node, dict):
        return {
            key: (
                transform(val)
                if key in _SECRET_FIELDS and isinstance(val, str)
                else _transform_secrets(val, transform)
            )
            for key, val in node.items()
        }
    if isinstance(node, list):
        return [_transform_secrets(item, transform) for item in node]
    return node


def _encrypt_secrets(data: dict) -> dict:
    """Deep-copy dict, encrypting secret fields.

    Values already tagged ``ENC:`` are left as-is. A value that failed to
    decrypt on load (and so still carries the tag) is therefore not
    encrypted a second time on save.
    """
    return _transform_secrets(
        data, lambda val: val if val.startswith("ENC:") else _encrypt(val)
    )


def _decrypt_secrets(data: dict) -> dict:
    """Deep-copy dict, decrypting secret fields."""
    return _transform_secrets(data, _decrypt)
class ConfigService:
"""Read / write config/dialectica.yaml."""
@staticmethod
def load() -> dict:
"""Load config, returning decrypted values. Empty dict if file missing."""
if not CONFIG_PATH.exists():
return {}
with open(CONFIG_PATH) as f:
raw = yaml.safe_load(f) or {}
return _decrypt_secrets(raw)
@staticmethod
def save(config: dict):
"""Save config, encrypting secret fields."""
CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
encrypted = _encrypt_secrets(config)
with open(CONFIG_PATH, "w") as f:
yaml.dump(encrypted, f, default_flow_style=False, allow_unicode=True)
@staticmethod
def is_db_configured() -> bool:
config = ConfigService.load()
db = config.get("database", {})
return bool(db.get("host") and db.get("database"))
@staticmethod
def get_database_url() -> str | None:
config = ConfigService.load()
db = config.get("database", {})
if not (db.get("host") and db.get("database")):
return None
user = db.get("user", "root")
password = db.get("password", "")
host = db["host"]
port = db.get("port", 3306)
database = db["database"]
return f"mysql+pymysql://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{database}"
@staticmethod
def is_initialized() -> bool:
return ConfigService.load().get("initialized", False)