diff --git a/app/api/routers/monitor.py b/app/api/routers/monitor.py index 6c37a14..1435e6f 100644 --- a/app/api/routers/monitor.py +++ b/app/api/routers/monitor.py @@ -23,6 +23,7 @@ from app.services.monitoring import ( get_server_states_view, test_provider_connection, ) +from app.services.crypto_box import get_public_key_info, decrypt_payload_b64, ts_within router = APIRouter(prefix='/monitor', tags=['Monitor']) SUPPORTED_PROVIDERS = {'anthropic', 'openai', 'minimax', 'kimi', 'qwen'} @@ -57,6 +58,11 @@ def require_admin(current_user: models.User = Depends(get_current_user_or_apikey return current_user +@router.get('/public/server-public-key') +def monitor_public_key(): + return get_public_key_info() + + @router.get('/public/overview') def public_overview(db: Session = Depends(get_db)): return { @@ -202,9 +208,22 @@ async def server_ws(websocket: WebSocket): server_id = None try: hello = await websocket.receive_json() - identifier = (hello.get('identifier') or '').strip() - challenge_uuid = (hello.get('challenge_uuid') or '').strip() - nonce = (hello.get('nonce') or '').strip() + + encrypted_payload = (hello.get('encrypted_payload') or '').strip() + if encrypted_payload: + data = decrypt_payload_b64(encrypted_payload) + identifier = (data.get('identifier') or '').strip() + challenge_uuid = (data.get('challenge_uuid') or '').strip() + nonce = (data.get('nonce') or '').strip() + ts = data.get('ts') + if not ts_within(ts, max_minutes=10): + await websocket.close(code=4401) + return + else: + # backward compatible mode + identifier = (hello.get('identifier') or '').strip() + challenge_uuid = (hello.get('challenge_uuid') or '').strip() + nonce = (hello.get('nonce') or '').strip() if not identifier or not challenge_uuid or not nonce: await websocket.close(code=4400) diff --git a/app/services/crypto_box.py b/app/services/crypto_box.py new file mode 100644 index 0000000..38423ba --- /dev/null +++ b/app/services/crypto_box.py @@ -0,0 +1,63 @@ +import base64 +import hashlib +import json +import os +from datetime import datetime, timezone +from pathlib import Path +from typing import Dict, Any + +from cryptography.hazmat.primitives import serialization, hashes +from cryptography.hazmat.primitives.asymmetric import rsa, padding + +KEY_DIR = Path(os.getenv('MONITOR_KEY_DIR', '/config/monitor_keys')) +PRIV_PATH = KEY_DIR / 'monitor_private.pem' +PUB_PATH = KEY_DIR / 'monitor_public.pem' + + +def ensure_keypair() -> None: + KEY_DIR.mkdir(parents=True, exist_ok=True) + if PRIV_PATH.exists() and PUB_PATH.exists(): + return + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + public_pem = private_key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + PRIV_PATH.write_bytes(private_pem) + PUB_PATH.write_bytes(public_pem) + + +def get_public_key_info() -> Dict[str, str]: + ensure_keypair() + pem = PUB_PATH.read_text() + kid = hashlib.sha256(pem.encode()).hexdigest()[:16] + return {'public_key_pem': pem, 'key_id': kid} + + +def decrypt_payload_b64(ciphertext_b64: str) -> Dict[str, Any]: + ensure_keypair() + private_key = serialization.load_pem_private_key(PRIV_PATH.read_bytes(), password=None) + plaintext = private_key.decrypt( + base64.b64decode(ciphertext_b64), + padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None, + ), + ) + obj = json.loads(plaintext.decode()) + return obj + + +def ts_within(ts_iso: str, max_minutes: int = 10) -> bool: + try: + ts = datetime.fromisoformat(ts_iso.replace('Z', '+00:00')) + except Exception: + return False + now = datetime.now(timezone.utc) + return abs((now - ts).total_seconds()) <= max_minutes * 60