feat: monitor API key flow and versioned telemetry #10

Merged
hzhang merged 5 commits from feat/monitor-api-key-v2 into main 2026-03-20 09:18:09 +00:00
2 changed files with 4 additions and 145 deletions
Showing only changes of commit 8e0f158266 - Show all commits

View File

@@ -1,22 +1,19 @@
from datetime import datetime, timedelta, timezone
from datetime import datetime, timezone
import json
import secrets
import uuid
from typing import List, Dict
from typing import List
from fastapi import APIRouter, Depends, Header, HTTPException, status, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Depends, Header, HTTPException, status
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.core.config import get_db, SessionLocal
from app.core.config import get_db
from app.api.deps import get_current_user_or_apikey
from app.models import models
from app.models.monitor import (
ProviderAccount,
MonitoredServer,
ServerState,
ServerChallenge,
ServerHandshakeNonce,
)
from app.services.monitoring import (
get_task_stats_cached,
@@ -24,11 +21,8 @@ from app.services.monitoring import (
get_server_states_view,
test_provider_connection,
)
from app.services.crypto_box import get_public_key_info, decrypt_payload_b64, ts_within
router = APIRouter(prefix='/monitor', tags=['Monitor'])
SUPPORTED_PROVIDERS = {'anthropic', 'openai', 'minimax', 'kimi', 'qwen'}
ACTIVE_WS: Dict[int, WebSocket] = {}
class ProviderAccountCreate(BaseModel):
@@ -47,23 +41,12 @@ class MonitoredServerCreate(BaseModel):
display_name: str | None = None
class ChallengeResponse(BaseModel):
identifier: str
challenge_uuid: str
expires_at: str
def require_admin(current_user: models.User = Depends(get_current_user_or_apikey)):
if not current_user.is_admin:
raise HTTPException(status_code=403, detail='Admin required')
return current_user
@router.get('/public/server-public-key')
def monitor_public_key():
return get_public_key_info()
@router.get('/public/overview')
def public_overview(db: Session = Depends(get_db)):
return {
@@ -144,19 +127,6 @@ def add_server(payload: MonitoredServerCreate, db: Session = Depends(get_db), us
return {'id': obj.id, 'identifier': obj.identifier, 'display_name': obj.display_name, 'is_enabled': obj.is_enabled}
@router.post('/admin/servers/{server_id}/challenge', response_model=ChallengeResponse)
def issue_server_challenge(server_id: int, db: Session = Depends(get_db), _: models.User = Depends(require_admin)):
server = db.query(MonitoredServer).filter(MonitoredServer.id == server_id).first()
if not server:
raise HTTPException(status_code=404, detail='Server not found')
challenge_uuid = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(minutes=10)
ch = ServerChallenge(server_id=server_id, challenge_uuid=challenge_uuid, expires_at=expires_at)
db.add(ch)
db.commit()
return ChallengeResponse(identifier=server.identifier, challenge_uuid=challenge_uuid, expires_at=expires_at.isoformat())
@router.delete('/admin/servers/{server_id}', status_code=status.HTTP_204_NO_CONTENT)
def delete_server(server_id: int, db: Session = Depends(get_db), _: models.User = Depends(require_admin)):
obj = db.query(MonitoredServer).filter(MonitoredServer.id == server_id).first()
@@ -165,8 +135,6 @@ def delete_server(server_id: int, db: Session = Depends(get_db), _: models.User
state = db.query(ServerState).filter(ServerState.server_id == server_id).first()
if state:
db.delete(state)
db.query(ServerChallenge).filter(ServerChallenge.server_id == server_id).delete()
db.query(ServerHandshakeNonce).filter(ServerHandshakeNonce.server_id == server_id).delete()
db.delete(obj)
db.commit()
return None
@@ -269,93 +237,3 @@ def server_heartbeat_v2(
db.commit()
return {'ok': True, 'server_id': server.id, 'identifier': server.identifier, 'last_seen_at': st.last_seen_at}
@router.websocket('/server/ws')
async def server_ws(websocket: WebSocket):
await websocket.accept()
db = SessionLocal()
server_id = None
try:
hello = await websocket.receive_json()
encrypted_payload = (hello.get('encrypted_payload') or '').strip()
if encrypted_payload:
data = decrypt_payload_b64(encrypted_payload)
identifier = (data.get('identifier') or '').strip()
challenge_uuid = (data.get('challenge_uuid') or '').strip()
nonce = (data.get('nonce') or '').strip()
ts = data.get('ts')
if not ts_within(ts, max_minutes=10):
await websocket.close(code=4401)
return
else:
# backward compatible mode
identifier = (hello.get('identifier') or '').strip()
challenge_uuid = (hello.get('challenge_uuid') or '').strip()
nonce = (hello.get('nonce') or '').strip()
if not identifier or not challenge_uuid or not nonce:
await websocket.close(code=4400)
return
server = db.query(MonitoredServer).filter(MonitoredServer.identifier == identifier, MonitoredServer.is_enabled == True).first()
if not server:
await websocket.close(code=4404)
return
ch = db.query(ServerChallenge).filter(ServerChallenge.challenge_uuid == challenge_uuid, ServerChallenge.server_id == server.id).first()
if not ch or ch.used_at is not None or ch.expires_at < datetime.now(timezone.utc):
await websocket.close(code=4401)
return
nonce_used = db.query(ServerHandshakeNonce).filter(ServerHandshakeNonce.server_id == server.id, ServerHandshakeNonce.nonce == nonce).first()
if nonce_used:
await websocket.close(code=4409)
return
db.add(ServerHandshakeNonce(server_id=server.id, nonce=nonce))
ch.used_at = datetime.now(timezone.utc)
db.commit()
server_id = server.id
ACTIVE_WS[server.id] = websocket
await websocket.send_json({'ok': True, 'server_id': server.id, 'message': 'connected'})
while True:
msg = await websocket.receive_json()
event = msg.get('event')
payload = msg.get('payload') or {}
st = db.query(ServerState).filter(ServerState.server_id == server.id).first()
if not st:
st = ServerState(server_id=server.id)
db.add(st)
if event == 'server.hello':
st.openclaw_version = payload.get('openclaw_version')
st.plugin_version = payload.get('plugin_version')
st.agents_json = json.dumps(payload.get('agents') or [], ensure_ascii=False)
elif event in {'server.metrics', 'agent.status_changed'}:
st.cpu_pct = payload.get('cpu_pct', st.cpu_pct)
st.mem_pct = payload.get('mem_pct', st.mem_pct)
st.disk_pct = payload.get('disk_pct', st.disk_pct)
st.swap_pct = payload.get('swap_pct', st.swap_pct)
if 'openclaw_version' in payload:
st.openclaw_version = payload.get('openclaw_version')
if 'plugin_version' in payload:
st.plugin_version = payload.get('plugin_version')
if 'agents' in payload:
st.agents_json = json.dumps(payload.get('agents') or [], ensure_ascii=False)
st.last_seen_at = datetime.now(timezone.utc)
db.commit()
except WebSocketDisconnect:
pass
except Exception:
try:
await websocket.close(code=1011)
except Exception:
pass
finally:
if server_id and ACTIVE_WS.get(server_id) is websocket:
ACTIVE_WS.pop(server_id, None)
db.close()

View File

@@ -59,22 +59,3 @@ class ServerState(Base):
last_seen_at = Column(DateTime(timezone=True), nullable=True)
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
class ServerChallenge(Base):
__tablename__ = 'server_challenges'
id = Column(Integer, primary_key=True, index=True)
server_id = Column(Integer, ForeignKey('monitored_servers.id'), nullable=False, index=True)
challenge_uuid = Column(String(64), nullable=False, unique=True, index=True)
expires_at = Column(DateTime(timezone=True), nullable=False)
used_at = Column(DateTime(timezone=True), nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
class ServerHandshakeNonce(Base):
__tablename__ = 'server_handshake_nonces'
id = Column(Integer, primary_key=True, index=True)
server_id = Column(Integer, ForeignKey('monitored_servers.id'), nullable=False, index=True)
nonce = Column(String(128), nullable=False, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())