feat: add RSA public-key handshake support for monitor server websocket

This commit is contained in:
zhi
2026-03-11 12:51:54 +00:00
parent 464bccafd8
commit 9fb13f4906
2 changed files with 85 additions and 3 deletions

View File

@@ -23,6 +23,7 @@ from app.services.monitoring import (
get_server_states_view,
test_provider_connection,
)
from app.services.crypto_box import get_public_key_info, decrypt_payload_b64, ts_within
router = APIRouter(prefix='/monitor', tags=['Monitor'])
SUPPORTED_PROVIDERS = {'anthropic', 'openai', 'minimax', 'kimi', 'qwen'}
@@ -57,6 +58,11 @@ def require_admin(current_user: models.User = Depends(get_current_user_or_apikey
return current_user
@router.get('/public/server-public-key')
def monitor_public_key():
return get_public_key_info()
@router.get('/public/overview')
def public_overview(db: Session = Depends(get_db)):
return {
@@ -202,9 +208,22 @@ async def server_ws(websocket: WebSocket):
server_id = None
try:
hello = await websocket.receive_json()
identifier = (hello.get('identifier') or '').strip()
challenge_uuid = (hello.get('challenge_uuid') or '').strip()
nonce = (hello.get('nonce') or '').strip()
encrypted_payload = (hello.get('encrypted_payload') or '').strip()
if encrypted_payload:
data = decrypt_payload_b64(encrypted_payload)
identifier = (data.get('identifier') or '').strip()
challenge_uuid = (data.get('challenge_uuid') or '').strip()
nonce = (data.get('nonce') or '').strip()
ts = data.get('ts')
if not ts_within(ts, max_minutes=10):
await websocket.close(code=4401)
return
else:
# backward compatible mode
identifier = (hello.get('identifier') or '').strip()
challenge_uuid = (hello.get('challenge_uuid') or '').strip()
nonce = (hello.get('nonce') or '').strip()
if not identifier or not challenge_uuid or not nonce:
await websocket.close(code=4400)

View File

@@ -0,0 +1,63 @@
import base64
import hashlib
import json
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Any
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding
KEY_DIR = Path(os.getenv('MONITOR_KEY_DIR', '/config/monitor_keys'))
PRIV_PATH = KEY_DIR / 'monitor_private.pem'
PUB_PATH = KEY_DIR / 'monitor_public.pem'
def ensure_keypair() -> None:
KEY_DIR.mkdir(parents=True, exist_ok=True)
if PRIV_PATH.exists() and PUB_PATH.exists():
return
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
private_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
public_pem = private_key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
PRIV_PATH.write_bytes(private_pem)
PUB_PATH.write_bytes(public_pem)
def get_public_key_info() -> Dict[str, str]:
ensure_keypair()
pem = PUB_PATH.read_text()
kid = hashlib.sha256(pem.encode()).hexdigest()[:16]
return {'public_key_pem': pem, 'key_id': kid}
def decrypt_payload_b64(ciphertext_b64: str) -> Dict[str, Any]:
ensure_keypair()
private_key = serialization.load_pem_private_key(PRIV_PATH.read_bytes(), password=None)
plaintext = private_key.decrypt(
base64.b64decode(ciphertext_b64),
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None,
),
)
obj = json.loads(plaintext.decode())
return obj
def ts_within(ts_iso: str, max_minutes: int = 10) -> bool:
try:
ts = datetime.fromisoformat(ts_iso.replace('Z', '+00:00'))
except Exception:
return False
now = datetime.now(timezone.utc)
return abs((now - ts).total_seconds()) <= max_minutes * 60