feat: add 10m server challenge flow and websocket telemetry channel

This commit is contained in:
zhi
2026-03-11 12:41:32 +00:00
parent d299428d35
commit 464bccafd8
2 changed files with 126 additions and 7 deletions

View File

@@ -1,15 +1,22 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
import json import json
from typing import List import uuid
from typing import List, Dict
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status, WebSocket, WebSocketDisconnect
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.config import get_db from app.core.config import get_db, SessionLocal
from app.api.deps import get_current_user_or_apikey from app.api.deps import get_current_user_or_apikey
from app.models import models from app.models import models
from app.models.monitor import ProviderAccount, MonitoredServer, ServerState from app.models.monitor import (
ProviderAccount,
MonitoredServer,
ServerState,
ServerChallenge,
ServerHandshakeNonce,
)
from app.services.monitoring import ( from app.services.monitoring import (
get_issue_stats_cached, get_issue_stats_cached,
get_provider_usage_view, get_provider_usage_view,
@@ -18,9 +25,8 @@ from app.services.monitoring import (
) )
router = APIRouter(prefix='/monitor', tags=['Monitor']) router = APIRouter(prefix='/monitor', tags=['Monitor'])
SUPPORTED_PROVIDERS = {'anthropic', 'openai', 'minimax', 'kimi', 'qwen'} SUPPORTED_PROVIDERS = {'anthropic', 'openai', 'minimax', 'kimi', 'qwen'}
ACTIVE_WS: Dict[int, WebSocket] = {}
class ProviderAccountCreate(BaseModel): class ProviderAccountCreate(BaseModel):
@@ -39,6 +45,12 @@ class MonitoredServerCreate(BaseModel):
display_name: str | None = None display_name: str | None = None
class ChallengeResponse(BaseModel):
identifier: str
challenge_uuid: str
expires_at: str
def require_admin(current_user: models.User = Depends(get_current_user_or_apikey)): def require_admin(current_user: models.User = Depends(get_current_user_or_apikey)):
if not current_user.is_admin: if not current_user.is_admin:
raise HTTPException(status_code=403, detail='Admin required') raise HTTPException(status_code=403, detail='Admin required')
@@ -125,6 +137,19 @@ def add_server(payload: MonitoredServerCreate, db: Session = Depends(get_db), us
return {'id': obj.id, 'identifier': obj.identifier, 'display_name': obj.display_name, 'is_enabled': obj.is_enabled} return {'id': obj.id, 'identifier': obj.identifier, 'display_name': obj.display_name, 'is_enabled': obj.is_enabled}
@router.post('/admin/servers/{server_id}/challenge', response_model=ChallengeResponse)
def issue_server_challenge(server_id: int, db: Session = Depends(get_db), _: models.User = Depends(require_admin)):
server = db.query(MonitoredServer).filter(MonitoredServer.id == server_id).first()
if not server:
raise HTTPException(status_code=404, detail='Server not found')
challenge_uuid = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(minutes=10)
ch = ServerChallenge(server_id=server_id, challenge_uuid=challenge_uuid, expires_at=expires_at)
db.add(ch)
db.commit()
return ChallengeResponse(identifier=server.identifier, challenge_uuid=challenge_uuid, expires_at=expires_at.isoformat())
@router.delete('/admin/servers/{server_id}', status_code=status.HTTP_204_NO_CONTENT) @router.delete('/admin/servers/{server_id}', status_code=status.HTTP_204_NO_CONTENT)
def delete_server(server_id: int, db: Session = Depends(get_db), _: models.User = Depends(require_admin)): def delete_server(server_id: int, db: Session = Depends(get_db), _: models.User = Depends(require_admin)):
obj = db.query(MonitoredServer).filter(MonitoredServer.id == server_id).first() obj = db.query(MonitoredServer).filter(MonitoredServer.id == server_id).first()
@@ -133,12 +158,13 @@ def delete_server(server_id: int, db: Session = Depends(get_db), _: models.User
state = db.query(ServerState).filter(ServerState.server_id == server_id).first() state = db.query(ServerState).filter(ServerState.server_id == server_id).first()
if state: if state:
db.delete(state) db.delete(state)
db.query(ServerChallenge).filter(ServerChallenge.server_id == server_id).delete()
db.query(ServerHandshakeNonce).filter(ServerHandshakeNonce.server_id == server_id).delete()
db.delete(obj) db.delete(obj)
db.commit() db.commit()
return None return None
# Temporary ingestion endpoint before WS plugin lands
class ServerHeartbeat(BaseModel): class ServerHeartbeat(BaseModel):
identifier: str identifier: str
openclaw_version: str | None = None openclaw_version: str | None = None
@@ -167,3 +193,76 @@ def server_heartbeat(payload: ServerHeartbeat, db: Session = Depends(get_db)):
st.last_seen_at = datetime.now(timezone.utc) st.last_seen_at = datetime.now(timezone.utc)
db.commit() db.commit()
return {'ok': True, 'server_id': server.id, 'last_seen_at': st.last_seen_at} return {'ok': True, 'server_id': server.id, 'last_seen_at': st.last_seen_at}
@router.websocket('/server/ws')
async def server_ws(websocket: WebSocket):
await websocket.accept()
db = SessionLocal()
server_id = None
try:
hello = await websocket.receive_json()
identifier = (hello.get('identifier') or '').strip()
challenge_uuid = (hello.get('challenge_uuid') or '').strip()
nonce = (hello.get('nonce') or '').strip()
if not identifier or not challenge_uuid or not nonce:
await websocket.close(code=4400)
return
server = db.query(MonitoredServer).filter(MonitoredServer.identifier == identifier, MonitoredServer.is_enabled == True).first()
if not server:
await websocket.close(code=4404)
return
ch = db.query(ServerChallenge).filter(ServerChallenge.challenge_uuid == challenge_uuid, ServerChallenge.server_id == server.id).first()
if not ch or ch.used_at is not None or ch.expires_at < datetime.now(timezone.utc):
await websocket.close(code=4401)
return
nonce_used = db.query(ServerHandshakeNonce).filter(ServerHandshakeNonce.server_id == server.id, ServerHandshakeNonce.nonce == nonce).first()
if nonce_used:
await websocket.close(code=4409)
return
db.add(ServerHandshakeNonce(server_id=server.id, nonce=nonce))
ch.used_at = datetime.now(timezone.utc)
db.commit()
server_id = server.id
ACTIVE_WS[server.id] = websocket
await websocket.send_json({'ok': True, 'server_id': server.id, 'message': 'connected'})
while True:
msg = await websocket.receive_json()
event = msg.get('event')
payload = msg.get('payload') or {}
st = db.query(ServerState).filter(ServerState.server_id == server.id).first()
if not st:
st = ServerState(server_id=server.id)
db.add(st)
if event == 'server.hello':
st.openclaw_version = payload.get('openclaw_version')
st.agents_json = json.dumps(payload.get('agents') or [], ensure_ascii=False)
elif event in {'server.metrics', 'agent.status_changed'}:
st.cpu_pct = payload.get('cpu_pct', st.cpu_pct)
st.mem_pct = payload.get('mem_pct', st.mem_pct)
st.disk_pct = payload.get('disk_pct', st.disk_pct)
st.swap_pct = payload.get('swap_pct', st.swap_pct)
if 'agents' in payload:
st.agents_json = json.dumps(payload.get('agents') or [], ensure_ascii=False)
st.last_seen_at = datetime.now(timezone.utc)
db.commit()
except WebSocketDisconnect:
pass
except Exception:
try:
await websocket.close(code=1011)
except Exception:
pass
finally:
if server_id and ACTIVE_WS.get(server_id) is websocket:
ACTIVE_WS.pop(server_id, None)
db.close()

View File

@@ -56,3 +56,23 @@ class ServerState(Base):
swap_pct = Column(Float, nullable=True) swap_pct = Column(Float, nullable=True)
last_seen_at = Column(DateTime(timezone=True), nullable=True) last_seen_at = Column(DateTime(timezone=True), nullable=True)
updated_at = Column(DateTime(timezone=True), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now())
class ServerChallenge(Base):
__tablename__ = 'server_challenges'
id = Column(Integer, primary_key=True, index=True)
server_id = Column(Integer, ForeignKey('monitored_servers.id'), nullable=False, index=True)
challenge_uuid = Column(String(64), nullable=False, unique=True, index=True)
expires_at = Column(DateTime(timezone=True), nullable=False)
used_at = Column(DateTime(timezone=True), nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
class ServerHandshakeNonce(Base):
__tablename__ = 'server_handshake_nonces'
id = Column(Integer, primary_key=True, index=True)
server_id = Column(Integer, ForeignKey('monitored_servers.id'), nullable=False, index=True)
nonce = Column(String(128), nullable=False, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())