feat: add RSA public-key handshake support for monitor server websocket
This commit is contained in:
@@ -23,6 +23,7 @@ from app.services.monitoring import (
|
||||
get_server_states_view,
|
||||
test_provider_connection,
|
||||
)
|
||||
from app.services.crypto_box import get_public_key_info, decrypt_payload_b64, ts_within
|
||||
|
||||
router = APIRouter(prefix='/monitor', tags=['Monitor'])
|
||||
SUPPORTED_PROVIDERS = {'anthropic', 'openai', 'minimax', 'kimi', 'qwen'}
|
||||
@@ -57,6 +58,11 @@ def require_admin(current_user: models.User = Depends(get_current_user_or_apikey
|
||||
return current_user
|
||||
|
||||
|
||||
@router.get('/public/server-public-key')
|
||||
def monitor_public_key():
|
||||
return get_public_key_info()
|
||||
|
||||
|
||||
@router.get('/public/overview')
|
||||
def public_overview(db: Session = Depends(get_db)):
|
||||
return {
|
||||
@@ -202,9 +208,22 @@ async def server_ws(websocket: WebSocket):
|
||||
server_id = None
|
||||
try:
|
||||
hello = await websocket.receive_json()
|
||||
identifier = (hello.get('identifier') or '').strip()
|
||||
challenge_uuid = (hello.get('challenge_uuid') or '').strip()
|
||||
nonce = (hello.get('nonce') or '').strip()
|
||||
|
||||
encrypted_payload = (hello.get('encrypted_payload') or '').strip()
|
||||
if encrypted_payload:
|
||||
data = decrypt_payload_b64(encrypted_payload)
|
||||
identifier = (data.get('identifier') or '').strip()
|
||||
challenge_uuid = (data.get('challenge_uuid') or '').strip()
|
||||
nonce = (data.get('nonce') or '').strip()
|
||||
ts = data.get('ts')
|
||||
if not ts_within(ts, max_minutes=10):
|
||||
await websocket.close(code=4401)
|
||||
return
|
||||
else:
|
||||
# backward compatible mode
|
||||
identifier = (hello.get('identifier') or '').strip()
|
||||
challenge_uuid = (hello.get('challenge_uuid') or '').strip()
|
||||
nonce = (hello.get('nonce') or '').strip()
|
||||
|
||||
if not identifier or not challenge_uuid or not nonce:
|
||||
await websocket.close(code=4400)
|
||||
|
||||
63
app/services/crypto_box.py
Normal file
63
app/services/crypto_box.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
|
||||
from cryptography.hazmat.primitives import serialization, hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa, padding
|
||||
|
||||
KEY_DIR = Path(os.getenv('MONITOR_KEY_DIR', '/config/monitor_keys'))
|
||||
PRIV_PATH = KEY_DIR / 'monitor_private.pem'
|
||||
PUB_PATH = KEY_DIR / 'monitor_public.pem'
|
||||
|
||||
|
||||
def ensure_keypair() -> None:
|
||||
KEY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
if PRIV_PATH.exists() and PUB_PATH.exists():
|
||||
return
|
||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
private_pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
public_pem = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
PRIV_PATH.write_bytes(private_pem)
|
||||
PUB_PATH.write_bytes(public_pem)
|
||||
|
||||
|
||||
def get_public_key_info() -> Dict[str, str]:
|
||||
ensure_keypair()
|
||||
pem = PUB_PATH.read_text()
|
||||
kid = hashlib.sha256(pem.encode()).hexdigest()[:16]
|
||||
return {'public_key_pem': pem, 'key_id': kid}
|
||||
|
||||
|
||||
def decrypt_payload_b64(ciphertext_b64: str) -> Dict[str, Any]:
|
||||
ensure_keypair()
|
||||
private_key = serialization.load_pem_private_key(PRIV_PATH.read_bytes(), password=None)
|
||||
plaintext = private_key.decrypt(
|
||||
base64.b64decode(ciphertext_b64),
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(algorithm=hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None,
|
||||
),
|
||||
)
|
||||
obj = json.loads(plaintext.decode())
|
||||
return obj
|
||||
|
||||
|
||||
def ts_within(ts_iso: str, max_minutes: int = 10) -> bool:
|
||||
try:
|
||||
ts = datetime.fromisoformat(ts_iso.replace('Z', '+00:00'))
|
||||
except Exception:
|
||||
return False
|
||||
now = datetime.now(timezone.utc)
|
||||
return abs((now - ts).total_seconds()) <= max_minutes * 60
|
||||
Reference in New Issue
Block a user