feat: add RSA public-key handshake support for monitor server websocket
This commit is contained in:
@@ -23,6 +23,7 @@ from app.services.monitoring import (
|
|||||||
get_server_states_view,
|
get_server_states_view,
|
||||||
test_provider_connection,
|
test_provider_connection,
|
||||||
)
|
)
|
||||||
|
from app.services.crypto_box import get_public_key_info, decrypt_payload_b64, ts_within
|
||||||
|
|
||||||
router = APIRouter(prefix='/monitor', tags=['Monitor'])
|
router = APIRouter(prefix='/monitor', tags=['Monitor'])
|
||||||
SUPPORTED_PROVIDERS = {'anthropic', 'openai', 'minimax', 'kimi', 'qwen'}
|
SUPPORTED_PROVIDERS = {'anthropic', 'openai', 'minimax', 'kimi', 'qwen'}
|
||||||
@@ -57,6 +58,11 @@ def require_admin(current_user: models.User = Depends(get_current_user_or_apikey
|
|||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('/public/server-public-key')
|
||||||
|
def monitor_public_key():
|
||||||
|
return get_public_key_info()
|
||||||
|
|
||||||
|
|
||||||
@router.get('/public/overview')
|
@router.get('/public/overview')
|
||||||
def public_overview(db: Session = Depends(get_db)):
|
def public_overview(db: Session = Depends(get_db)):
|
||||||
return {
|
return {
|
||||||
@@ -202,9 +208,22 @@ async def server_ws(websocket: WebSocket):
|
|||||||
server_id = None
|
server_id = None
|
||||||
try:
|
try:
|
||||||
hello = await websocket.receive_json()
|
hello = await websocket.receive_json()
|
||||||
identifier = (hello.get('identifier') or '').strip()
|
|
||||||
challenge_uuid = (hello.get('challenge_uuid') or '').strip()
|
encrypted_payload = (hello.get('encrypted_payload') or '').strip()
|
||||||
nonce = (hello.get('nonce') or '').strip()
|
if encrypted_payload:
|
||||||
|
data = decrypt_payload_b64(encrypted_payload)
|
||||||
|
identifier = (data.get('identifier') or '').strip()
|
||||||
|
challenge_uuid = (data.get('challenge_uuid') or '').strip()
|
||||||
|
nonce = (data.get('nonce') or '').strip()
|
||||||
|
ts = data.get('ts')
|
||||||
|
if not ts_within(ts, max_minutes=10):
|
||||||
|
await websocket.close(code=4401)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# backward compatible mode
|
||||||
|
identifier = (hello.get('identifier') or '').strip()
|
||||||
|
challenge_uuid = (hello.get('challenge_uuid') or '').strip()
|
||||||
|
nonce = (hello.get('nonce') or '').strip()
|
||||||
|
|
||||||
if not identifier or not challenge_uuid or not nonce:
|
if not identifier or not challenge_uuid or not nonce:
|
||||||
await websocket.close(code=4400)
|
await websocket.close(code=4400)
|
||||||
|
|||||||
63
app/services/crypto_box.py
Normal file
63
app/services/crypto_box.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from cryptography.hazmat.primitives import serialization, hashes
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa, padding
|
||||||
|
|
||||||
|
KEY_DIR = Path(os.getenv('MONITOR_KEY_DIR', '/config/monitor_keys'))
|
||||||
|
PRIV_PATH = KEY_DIR / 'monitor_private.pem'
|
||||||
|
PUB_PATH = KEY_DIR / 'monitor_public.pem'
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_keypair() -> None:
|
||||||
|
KEY_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
if PRIV_PATH.exists() and PUB_PATH.exists():
|
||||||
|
return
|
||||||
|
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||||
|
private_pem = private_key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||||
|
encryption_algorithm=serialization.NoEncryption(),
|
||||||
|
)
|
||||||
|
public_pem = private_key.public_key().public_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
|
)
|
||||||
|
PRIV_PATH.write_bytes(private_pem)
|
||||||
|
PUB_PATH.write_bytes(public_pem)
|
||||||
|
|
||||||
|
|
||||||
|
def get_public_key_info() -> Dict[str, str]:
|
||||||
|
ensure_keypair()
|
||||||
|
pem = PUB_PATH.read_text()
|
||||||
|
kid = hashlib.sha256(pem.encode()).hexdigest()[:16]
|
||||||
|
return {'public_key_pem': pem, 'key_id': kid}
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_payload_b64(ciphertext_b64: str) -> Dict[str, Any]:
|
||||||
|
ensure_keypair()
|
||||||
|
private_key = serialization.load_pem_private_key(PRIV_PATH.read_bytes(), password=None)
|
||||||
|
plaintext = private_key.decrypt(
|
||||||
|
base64.b64decode(ciphertext_b64),
|
||||||
|
padding.OAEP(
|
||||||
|
mgf=padding.MGF1(algorithm=hashes.SHA256()),
|
||||||
|
algorithm=hashes.SHA256(),
|
||||||
|
label=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
obj = json.loads(plaintext.decode())
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def ts_within(ts_iso: str, max_minutes: int = 10) -> bool:
|
||||||
|
try:
|
||||||
|
ts = datetime.fromisoformat(ts_iso.replace('Z', '+00:00'))
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
return abs((now - ts).total_seconds()) <= max_minutes * 60
|
||||||
Reference in New Issue
Block a user