From 464bccafd812a5bac31313d98509599cbb1a6ada Mon Sep 17 00:00:00 2001 From: zhi Date: Wed, 11 Mar 2026 12:41:32 +0000 Subject: [PATCH] feat: add 10m server challenge flow and websocket telemetry channel --- app/api/routers/monitor.py | 113 ++++++++++++++++++++++++++++++++++--- app/models/monitor.py | 20 +++++++ 2 files changed, 126 insertions(+), 7 deletions(-) diff --git a/app/api/routers/monitor.py b/app/api/routers/monitor.py index d4dfbef..6c37a14 100644 --- a/app/api/routers/monitor.py +++ b/app/api/routers/monitor.py @@ -1,15 +1,22 @@ from datetime import datetime, timedelta, timezone import json -from typing import List +import uuid +from typing import List, Dict -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, status, WebSocket, WebSocketDisconnect from pydantic import BaseModel from sqlalchemy.orm import Session -from app.core.config import get_db +from app.core.config import get_db, SessionLocal from app.api.deps import get_current_user_or_apikey from app.models import models -from app.models.monitor import ProviderAccount, MonitoredServer, ServerState +from app.models.monitor import ( + ProviderAccount, + MonitoredServer, + ServerState, + ServerChallenge, + ServerHandshakeNonce, +) from app.services.monitoring import ( get_issue_stats_cached, get_provider_usage_view, @@ -18,9 +25,8 @@ from app.services.monitoring import ( ) router = APIRouter(prefix='/monitor', tags=['Monitor']) - - SUPPORTED_PROVIDERS = {'anthropic', 'openai', 'minimax', 'kimi', 'qwen'} +ACTIVE_WS: Dict[int, WebSocket] = {} class ProviderAccountCreate(BaseModel): @@ -39,6 +45,12 @@ class MonitoredServerCreate(BaseModel): display_name: str | None = None +class ChallengeResponse(BaseModel): + identifier: str + challenge_uuid: str + expires_at: str + + def require_admin(current_user: models.User = Depends(get_current_user_or_apikey)): if not current_user.is_admin: raise HTTPException(status_code=403, detail='Admin required') @@ -125,6 +137,19 @@ def add_server(payload: MonitoredServerCreate, db: Session = Depends(get_db), us return {'id': obj.id, 'identifier': obj.identifier, 'display_name': obj.display_name, 'is_enabled': obj.is_enabled} +@router.post('/admin/servers/{server_id}/challenge', response_model=ChallengeResponse) +def issue_server_challenge(server_id: int, db: Session = Depends(get_db), _: models.User = Depends(require_admin)): + server = db.query(MonitoredServer).filter(MonitoredServer.id == server_id).first() + if not server: + raise HTTPException(status_code=404, detail='Server not found') + challenge_uuid = str(uuid.uuid4()) + expires_at = datetime.now(timezone.utc) + timedelta(minutes=10) + ch = ServerChallenge(server_id=server_id, challenge_uuid=challenge_uuid, expires_at=expires_at) + db.add(ch) + db.commit() + return ChallengeResponse(identifier=server.identifier, challenge_uuid=challenge_uuid, expires_at=expires_at.isoformat()) + + @router.delete('/admin/servers/{server_id}', status_code=status.HTTP_204_NO_CONTENT) def delete_server(server_id: int, db: Session = Depends(get_db), _: models.User = Depends(require_admin)): obj = db.query(MonitoredServer).filter(MonitoredServer.id == server_id).first() @@ -133,12 +158,13 @@ def delete_server(server_id: int, db: Session = Depends(get_db), _: models.User state = db.query(ServerState).filter(ServerState.server_id == server_id).first() if state: db.delete(state) + db.query(ServerChallenge).filter(ServerChallenge.server_id == server_id).delete() + db.query(ServerHandshakeNonce).filter(ServerHandshakeNonce.server_id == server_id).delete() db.delete(obj) db.commit() return None -# Temporary ingestion endpoint before WS plugin lands class ServerHeartbeat(BaseModel): identifier: str openclaw_version: str | None = None @@ -167,3 +193,76 @@ def server_heartbeat(payload: ServerHeartbeat, db: Session = Depends(get_db)): st.last_seen_at = datetime.now(timezone.utc) db.commit() return {'ok': True, 'server_id': server.id, 'last_seen_at': st.last_seen_at} + + +@router.websocket('/server/ws') +async def server_ws(websocket: WebSocket): + await websocket.accept() + db = SessionLocal() + server_id = None + try: + hello = await websocket.receive_json() + identifier = (hello.get('identifier') or '').strip() + challenge_uuid = (hello.get('challenge_uuid') or '').strip() + nonce = (hello.get('nonce') or '').strip() + + if not identifier or not challenge_uuid or not nonce: + await websocket.close(code=4400) + return + + server = db.query(MonitoredServer).filter(MonitoredServer.identifier == identifier, MonitoredServer.is_enabled == True).first() + if not server: + await websocket.close(code=4404) + return + + ch = db.query(ServerChallenge).filter(ServerChallenge.challenge_uuid == challenge_uuid, ServerChallenge.server_id == server.id).first() + if not ch or ch.used_at is not None or ch.expires_at < datetime.now(timezone.utc): + await websocket.close(code=4401) + return + + nonce_used = db.query(ServerHandshakeNonce).filter(ServerHandshakeNonce.server_id == server.id, ServerHandshakeNonce.nonce == nonce).first() + if nonce_used: + await websocket.close(code=4409) + return + + db.add(ServerHandshakeNonce(server_id=server.id, nonce=nonce)) + ch.used_at = datetime.now(timezone.utc) + db.commit() + + server_id = server.id + ACTIVE_WS[server.id] = websocket + await websocket.send_json({'ok': True, 'server_id': server.id, 'message': 'connected'}) + + while True: + msg = await websocket.receive_json() + event = msg.get('event') + payload = msg.get('payload') or {} + st = db.query(ServerState).filter(ServerState.server_id == server.id).first() + if not st: + st = ServerState(server_id=server.id) + db.add(st) + + if event == 'server.hello': + st.openclaw_version = payload.get('openclaw_version') + st.agents_json = json.dumps(payload.get('agents') or [], ensure_ascii=False) + elif event in {'server.metrics', 'agent.status_changed'}: + st.cpu_pct = payload.get('cpu_pct', st.cpu_pct) + st.mem_pct = payload.get('mem_pct', st.mem_pct) + st.disk_pct = payload.get('disk_pct', st.disk_pct) + st.swap_pct = payload.get('swap_pct', st.swap_pct) + if 'agents' in payload: + st.agents_json = json.dumps(payload.get('agents') or [], ensure_ascii=False) + + st.last_seen_at = datetime.now(timezone.utc) + db.commit() + except WebSocketDisconnect: + pass + except Exception: + try: + await websocket.close(code=1011) + except Exception: + pass + finally: + if server_id and ACTIVE_WS.get(server_id) is websocket: + ACTIVE_WS.pop(server_id, None) + db.close() diff --git a/app/models/monitor.py b/app/models/monitor.py index 030e790..533dbcb 100644 --- a/app/models/monitor.py +++ b/app/models/monitor.py @@ -56,3 +56,23 @@ class ServerState(Base): swap_pct = Column(Float, nullable=True) last_seen_at = Column(DateTime(timezone=True), nullable=True) updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + + +class ServerChallenge(Base): + __tablename__ = 'server_challenges' + + id = Column(Integer, primary_key=True, index=True) + server_id = Column(Integer, ForeignKey('monitored_servers.id'), nullable=False, index=True) + challenge_uuid = Column(String(64), nullable=False, unique=True, index=True) + expires_at = Column(DateTime(timezone=True), nullable=False) + used_at = Column(DateTime(timezone=True), nullable=True) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + +class ServerHandshakeNonce(Base): + __tablename__ = 'server_handshake_nonces' + + id = Column(Integer, primary_key=True, index=True) + server_id = Column(Integer, ForeignKey('monitored_servers.id'), nullable=False, index=True) + nonce = Column(String(128), nullable=False, index=True) + created_at = Column(DateTime(timezone=True), server_default=func.now())