feat: add 10m server challenge flow and websocket telemetry channel
This commit is contained in:
@@ -1,15 +1,22 @@
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
import json
|
import json
|
||||||
from typing import List
|
import uuid
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status, WebSocket, WebSocketDisconnect
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.config import get_db
|
from app.core.config import get_db, SessionLocal
|
||||||
from app.api.deps import get_current_user_or_apikey
|
from app.api.deps import get_current_user_or_apikey
|
||||||
from app.models import models
|
from app.models import models
|
||||||
from app.models.monitor import ProviderAccount, MonitoredServer, ServerState
|
from app.models.monitor import (
|
||||||
|
ProviderAccount,
|
||||||
|
MonitoredServer,
|
||||||
|
ServerState,
|
||||||
|
ServerChallenge,
|
||||||
|
ServerHandshakeNonce,
|
||||||
|
)
|
||||||
from app.services.monitoring import (
|
from app.services.monitoring import (
|
||||||
get_issue_stats_cached,
|
get_issue_stats_cached,
|
||||||
get_provider_usage_view,
|
get_provider_usage_view,
|
||||||
@@ -18,9 +25,8 @@ from app.services.monitoring import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
router = APIRouter(prefix='/monitor', tags=['Monitor'])
|
router = APIRouter(prefix='/monitor', tags=['Monitor'])
|
||||||
|
|
||||||
|
|
||||||
SUPPORTED_PROVIDERS = {'anthropic', 'openai', 'minimax', 'kimi', 'qwen'}
|
SUPPORTED_PROVIDERS = {'anthropic', 'openai', 'minimax', 'kimi', 'qwen'}
|
||||||
|
# In-process registry of live telemetry sockets, keyed by MonitoredServer.id.
# server_ws stores a socket here after a successful handshake and removes it
# on exit only if the stored entry is still that same socket.
ACTIVE_WS: Dict[int, WebSocket] = {}
|
||||||
|
|
||||||
|
|
||||||
class ProviderAccountCreate(BaseModel):
|
class ProviderAccountCreate(BaseModel):
|
||||||
@@ -39,6 +45,12 @@ class MonitoredServerCreate(BaseModel):
|
|||||||
display_name: str | None = None
|
display_name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChallengeResponse(BaseModel):
    """Response body returned when a handshake challenge is issued for a server."""

    # Identifier of the monitored server the challenge was issued for.
    identifier: str
    # String form of the generated challenge UUID.
    challenge_uuid: str
    # Expiry timestamp, already serialized to an ISO-8601 string by the caller.
    expires_at: str
|
||||||
|
|
||||||
|
|
||||||
def require_admin(current_user: models.User = Depends(get_current_user_or_apikey)):
|
def require_admin(current_user: models.User = Depends(get_current_user_or_apikey)):
|
||||||
if not current_user.is_admin:
|
if not current_user.is_admin:
|
||||||
raise HTTPException(status_code=403, detail='Admin required')
|
raise HTTPException(status_code=403, detail='Admin required')
|
||||||
@@ -125,6 +137,19 @@ def add_server(payload: MonitoredServerCreate, db: Session = Depends(get_db), us
|
|||||||
return {'id': obj.id, 'identifier': obj.identifier, 'display_name': obj.display_name, 'is_enabled': obj.is_enabled}
|
return {'id': obj.id, 'identifier': obj.identifier, 'display_name': obj.display_name, 'is_enabled': obj.is_enabled}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/admin/servers/{server_id}/challenge', response_model=ChallengeResponse)
def issue_server_challenge(server_id: int, db: Session = Depends(get_db), _: models.User = Depends(require_admin)):
    """Create and persist a handshake challenge for a monitored server.

    Args:
        server_id: Primary key of the MonitoredServer to issue the challenge for.
        db: Request-scoped database session.
        _: Unused; present only so the ``require_admin`` dependency runs.

    Returns:
        ChallengeResponse with the server identifier, the new challenge UUID
        and the expiry time as an ISO-8601 string.

    Raises:
        HTTPException: 404 when no server with ``server_id`` exists.
    """
    lookup = db.query(MonitoredServer).filter(MonitoredServer.id == server_id)
    target = lookup.first()
    if not target:
        raise HTTPException(status_code=404, detail='Server not found')

    token = str(uuid.uuid4())
    # The challenge is valid for ten minutes from the moment it is issued.
    deadline = datetime.now(timezone.utc) + timedelta(minutes=10)

    db.add(
        ServerChallenge(
            server_id=server_id,
            challenge_uuid=token,
            expires_at=deadline,
        )
    )
    db.commit()

    return ChallengeResponse(
        identifier=target.identifier,
        challenge_uuid=token,
        expires_at=deadline.isoformat(),
    )
|
||||||
|
|
||||||
|
|
||||||
@router.delete('/admin/servers/{server_id}', status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete('/admin/servers/{server_id}', status_code=status.HTTP_204_NO_CONTENT)
|
||||||
def delete_server(server_id: int, db: Session = Depends(get_db), _: models.User = Depends(require_admin)):
|
def delete_server(server_id: int, db: Session = Depends(get_db), _: models.User = Depends(require_admin)):
|
||||||
obj = db.query(MonitoredServer).filter(MonitoredServer.id == server_id).first()
|
obj = db.query(MonitoredServer).filter(MonitoredServer.id == server_id).first()
|
||||||
@@ -133,12 +158,13 @@ def delete_server(server_id: int, db: Session = Depends(get_db), _: models.User
|
|||||||
state = db.query(ServerState).filter(ServerState.server_id == server_id).first()
|
state = db.query(ServerState).filter(ServerState.server_id == server_id).first()
|
||||||
if state:
|
if state:
|
||||||
db.delete(state)
|
db.delete(state)
|
||||||
|
db.query(ServerChallenge).filter(ServerChallenge.server_id == server_id).delete()
|
||||||
|
db.query(ServerHandshakeNonce).filter(ServerHandshakeNonce.server_id == server_id).delete()
|
||||||
db.delete(obj)
|
db.delete(obj)
|
||||||
db.commit()
|
db.commit()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
# Temporary ingestion endpoint before WS plugin lands
|
|
||||||
class ServerHeartbeat(BaseModel):
|
class ServerHeartbeat(BaseModel):
|
||||||
identifier: str
|
identifier: str
|
||||||
openclaw_version: str | None = None
|
openclaw_version: str | None = None
|
||||||
@@ -167,3 +193,76 @@ def server_heartbeat(payload: ServerHeartbeat, db: Session = Depends(get_db)):
|
|||||||
st.last_seen_at = datetime.now(timezone.utc)
|
st.last_seen_at = datetime.now(timezone.utc)
|
||||||
db.commit()
|
db.commit()
|
||||||
return {'ok': True, 'server_id': server.id, 'last_seen_at': st.last_seen_at}
|
return {'ok': True, 'server_id': server.id, 'last_seen_at': st.last_seen_at}
|
||||||
|
|
||||||
|
|
||||||
|
def _as_aware_utc(moment):
    """Return *moment* as a timezone-aware datetime, treating naive values as UTC."""
    if moment.tzinfo is None:
        return moment.replace(tzinfo=timezone.utc)
    return moment


def _clean_text(value):
    """Return *value* stripped of surrounding whitespace, or '' when it is not a string."""
    return value.strip() if isinstance(value, str) else ''


def _apply_telemetry(st, event, payload):
    """Copy the fields carried by one telemetry event onto the ServerState row *st*."""
    if event == 'server.hello':
        st.openclaw_version = payload.get('openclaw_version')
        st.agents_json = json.dumps(payload.get('agents') or [], ensure_ascii=False)
    elif event in {'server.metrics', 'agent.status_changed'}:
        # Metrics missing from the payload keep their previously stored value.
        st.cpu_pct = payload.get('cpu_pct', st.cpu_pct)
        st.mem_pct = payload.get('mem_pct', st.mem_pct)
        st.disk_pct = payload.get('disk_pct', st.disk_pct)
        st.swap_pct = payload.get('swap_pct', st.swap_pct)
        if 'agents' in payload:
            st.agents_json = json.dumps(payload.get('agents') or [], ensure_ascii=False)


@router.websocket('/server/ws')
async def server_ws(websocket: WebSocket):
    """Authenticate a monitored server over WebSocket and ingest its telemetry.

    The first frame must be a JSON object with non-empty string fields
    ``identifier``, ``challenge_uuid`` and ``nonce``. The socket is closed with:

    * 4400 - hello frame is not an object, or a required field is missing/blank
    * 4404 - no enabled server matches ``identifier``
    * 4401 - challenge unknown for this server, already used, or expired
    * 4409 - nonce was already recorded for this server
    * 1011 - unexpected server-side failure (logged)

    After a successful handshake the challenge is marked used, the nonce is
    stored, the socket is registered in ``ACTIVE_WS`` and every subsequent
    JSON object frame (``{'event': ..., 'payload': {...}}``) updates the
    server's ServerState row and its ``last_seen_at``.

    NOTE(review): the database calls here are synchronous and run inside the
    async handler; the nonce check-then-insert is not atomic; and a socket
    already registered for the same server is replaced without being closed.
    Confirm whether these are acceptable.
    """
    # Function-scope import keeps this block self-contained without touching
    # the module's import section.
    import logging

    await websocket.accept()
    db = SessionLocal()
    server_id = None
    try:
        hello = await websocket.receive_json()
        if not isinstance(hello, dict):
            # Valid JSON but not an object: a malformed request, not a server fault.
            await websocket.close(code=4400)
            return

        identifier = _clean_text(hello.get('identifier'))
        challenge_uuid = _clean_text(hello.get('challenge_uuid'))
        nonce = _clean_text(hello.get('nonce'))

        if not identifier or not challenge_uuid or not nonce:
            await websocket.close(code=4400)
            return

        # `== True` is intentional: it builds a SQL expression, not a Python comparison.
        server = db.query(MonitoredServer).filter(MonitoredServer.identifier == identifier, MonitoredServer.is_enabled == True).first()  # noqa: E712
        if not server:
            await websocket.close(code=4404)
            return

        ch = db.query(ServerChallenge).filter(ServerChallenge.challenge_uuid == challenge_uuid, ServerChallenge.server_id == server.id).first()
        # Some backends hand back naive datetimes even for timezone=True columns;
        # normalise before comparing so the check cannot raise TypeError.
        if not ch or ch.used_at is not None or _as_aware_utc(ch.expires_at) < datetime.now(timezone.utc):
            await websocket.close(code=4401)
            return

        nonce_used = db.query(ServerHandshakeNonce).filter(ServerHandshakeNonce.server_id == server.id, ServerHandshakeNonce.nonce == nonce).first()
        if nonce_used:
            await websocket.close(code=4409)
            return

        # Read the id once, before committing, so the ORM object is never
        # touched again (attribute access after commit may trigger a reload).
        server_id = server.id
        db.add(ServerHandshakeNonce(server_id=server_id, nonce=nonce))
        ch.used_at = datetime.now(timezone.utc)
        db.commit()

        ACTIVE_WS[server_id] = websocket
        await websocket.send_json({'ok': True, 'server_id': server_id, 'message': 'connected'})

        while True:
            msg = await websocket.receive_json()
            if not isinstance(msg, dict):
                # Ignore frames that are not JSON objects instead of dropping the link.
                continue
            event = msg.get('event')
            payload = msg.get('payload')
            if not isinstance(payload, dict):
                payload = {}

            st = db.query(ServerState).filter(ServerState.server_id == server_id).first()
            if not st:
                st = ServerState(server_id=server_id)
                db.add(st)

            _apply_telemetry(st, event, payload)

            # Any object frame, even with an unknown event, counts as a sign of life.
            st.last_seen_at = datetime.now(timezone.utc)
            db.commit()
    except WebSocketDisconnect:
        pass
    except Exception:
        # Record the failure instead of swallowing it, then close best-effort.
        logging.getLogger(__name__).exception('server_ws failed (server_id=%s)', server_id)
        try:
            await websocket.close(code=1011)
        except Exception:
            # The socket may already be closed; nothing more can be done.
            pass
    finally:
        # Only unregister if the registry still points at this socket, so a
        # newer connection for the same server is left in place.
        if server_id is not None and ACTIVE_WS.get(server_id) is websocket:
            ACTIVE_WS.pop(server_id, None)
        db.close()
|
||||||
|
|||||||
@@ -56,3 +56,23 @@ class ServerState(Base):
|
|||||||
swap_pct = Column(Float, nullable=True)
|
swap_pct = Column(Float, nullable=True)
|
||||||
last_seen_at = Column(DateTime(timezone=True), nullable=True)
|
last_seen_at = Column(DateTime(timezone=True), nullable=True)
|
||||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||||
|
|
||||||
|
|
||||||
|
class ServerChallenge(Base):
    """Handshake challenge issued for a monitored server.

    Rows carry an expiry time and a ``used_at`` marker, which is NULL until
    the challenge has been consumed.
    """

    __tablename__ = 'server_challenges'

    id = Column(Integer, primary_key=True, index=True)
    # Owning monitored server.
    server_id = Column(
        Integer,
        ForeignKey('monitored_servers.id'),
        nullable=False,
        index=True,
    )
    # Challenge token; unique across all servers.
    challenge_uuid = Column(
        String(64),
        nullable=False,
        unique=True,
        index=True,
    )
    # Moment after which the challenge must no longer be accepted.
    expires_at = Column(DateTime(timezone=True), nullable=False)
    # NULL while the challenge is still unused.
    used_at = Column(DateTime(timezone=True), nullable=True)
    # Filled in by the database at insert time.
    created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
|
|
||||||
|
class ServerHandshakeNonce(Base):
    """Nonce recorded for a monitored server during a handshake.

    NOTE(review): there is no uniqueness constraint on (server_id, nonce);
    confirm whether replay protection is meant to rely on application-level
    checks alone.
    """

    __tablename__ = 'server_handshake_nonces'

    id = Column(Integer, primary_key=True, index=True)
    # Monitored server that presented the nonce.
    server_id = Column(
        Integer,
        ForeignKey('monitored_servers.id'),
        nullable=False,
        index=True,
    )
    # Nonce value as stored for that server.
    nonce = Column(String(128), nullable=False, index=True)
    # Filled in by the database at insert time.
    created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|||||||
Reference in New Issue
Block a user