feat: add RSA public-key handshake support for monitor server websocket

This commit is contained in:
zhi
2026-03-11 12:51:54 +00:00
parent 464bccafd8
commit 9fb13f4906
2 changed files with 85 additions and 3 deletions

View File

@@ -0,0 +1,63 @@
import base64
import hashlib
import json
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Any
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding
KEY_DIR = Path(os.getenv('MONITOR_KEY_DIR', '/config/monitor_keys'))
PRIV_PATH = KEY_DIR / 'monitor_private.pem'
PUB_PATH = KEY_DIR / 'monitor_public.pem'
def ensure_keypair() -> None:
KEY_DIR.mkdir(parents=True, exist_ok=True)
if PRIV_PATH.exists() and PUB_PATH.exists():
return
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
private_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
public_pem = private_key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
PRIV_PATH.write_bytes(private_pem)
PUB_PATH.write_bytes(public_pem)
def get_public_key_info() -> Dict[str, str]:
ensure_keypair()
pem = PUB_PATH.read_text()
kid = hashlib.sha256(pem.encode()).hexdigest()[:16]
return {'public_key_pem': pem, 'key_id': kid}
def decrypt_payload_b64(ciphertext_b64: str) -> Dict[str, Any]:
ensure_keypair()
private_key = serialization.load_pem_private_key(PRIV_PATH.read_bytes(), password=None)
plaintext = private_key.decrypt(
base64.b64decode(ciphertext_b64),
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None,
),
)
obj = json.loads(plaintext.decode())
return obj
def ts_within(ts_iso: str, max_minutes: int = 10) -> bool:
try:
ts = datetime.fromisoformat(ts_iso.replace('Z', '+00:00'))
except Exception:
return False
now = datetime.now(timezone.utc)
return abs((now - ts).total_seconds()) <= max_minutes * 60