Files
HarborForge.Backend/app/api/routers/monitor.py
zhi cf8a43d5b2 feat(monitor): add API Key authentication for server heartbeat
- Add api_key field to MonitoredServer model
- Add migration to create api_key column with unique index
- Add /admin/servers/{id}/api-key endpoint for key generation
- Add /admin/servers/{id}/api-key DELETE endpoint for revocation
- Add /server/heartbeat-v2 endpoint with X-API-Key header authentication
- Add TelemetryPayload model with extended fields (load_avg, uptime_seconds)
- Add basic unit tests for API key functionality
2026-03-19 15:32:32 +00:00

356 lines
14 KiB
Python

from datetime import datetime, timedelta, timezone
import json
import secrets
import uuid
from typing import List, Dict
from fastapi import APIRouter, Depends, Header, HTTPException, status, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.core.config import get_db, SessionLocal
from app.api.deps import get_current_user_or_apikey
from app.models import models
from app.models.monitor import (
ProviderAccount,
MonitoredServer,
ServerState,
ServerChallenge,
ServerHandshakeNonce,
)
from app.services.monitoring import (
get_task_stats_cached,
get_provider_usage_view,
get_server_states_view,
test_provider_connection,
)
from app.services.crypto_box import get_public_key_info, decrypt_payload_b64, ts_within
# Every route declared in this module is mounted under the /monitor prefix.
router = APIRouter(prefix='/monitor', tags=['Monitor'])
# Provider names accepted by create_provider_account (input is lower-cased and stripped first).
SUPPORTED_PROVIDERS = {'anthropic', 'openai', 'minimax', 'kimi', 'qwen'}
# Open agent sockets keyed by MonitoredServer.id. server_ws stores the socket after a
# successful handshake and removes the entry when that same socket's handler exits.
ACTIVE_WS: Dict[int, WebSocket] = {}
class ProviderAccountCreate(BaseModel):
    """Request body for POST /admin/providers/accounts."""

    # Provider name; the endpoint lower-cases/strips it and checks it against SUPPORTED_PROVIDERS.
    provider: str
    # Free-form label; stored after strip().
    label: str
    # Secret for the provider; stored after strip() and only ever listed back masked.
    credential: str
class ProviderTestRequest(BaseModel):
    """Request body for POST /admin/providers/test."""

    # Provider name; lower-cased and stripped before being handed to test_provider_connection.
    provider: str
    # Credential to try; stripped before use. The test endpoint does not persist it.
    credential: str
class MonitoredServerCreate(BaseModel):
    """Request body for POST /admin/servers."""

    # Identifier for the server; add_server strips it, rejects blanks and rejects duplicates.
    identifier: str
    # Optional display name; stored as given.
    display_name: str | None = None
class ChallengeResponse(BaseModel):
    """Response body of POST /admin/servers/{server_id}/challenge."""

    # Identifier of the server the challenge was issued for.
    identifier: str
    # The freshly generated challenge, a uuid4 rendered as a string.
    challenge_uuid: str
    # Expiry instant as an ISO-8601 string (produced with datetime.isoformat()).
    expires_at: str
def require_admin(current_user: models.User = Depends(get_current_user_or_apikey)):
    """Dependency that only lets administrators through.

    The caller is resolved by ``get_current_user_or_apikey``; the same user is
    returned when ``is_admin`` is truthy, otherwise HTTP 403 is raised.
    """
    if current_user.is_admin:
        return current_user
    raise HTTPException(status_code=403, detail='Admin required')
@router.get('/public/server-public-key')
def monitor_public_key():
    """Return whatever ``get_public_key_info()`` reports; this route declares no auth dependency."""
    key_info = get_public_key_info()
    return key_info
@router.get('/public/overview')
def public_overview(db: Session = Depends(get_db)):
    """Return the public dashboard snapshot: task stats, provider usage and server states."""
    # Each section is computed by its service helper, in the same order as the response keys.
    task_stats = get_task_stats_cached(db, ttl_seconds=1800)
    provider_usage = get_provider_usage_view(db)
    server_states = get_server_states_view(db, offline_after_minutes=7)
    generated_at = datetime.now(timezone.utc).isoformat()
    return {
        'tasks': task_stats,
        'providers': provider_usage,
        'servers': server_states,
        'generated_at': generated_at,
    }
@router.get('/admin/providers/accounts')
def list_provider_accounts(db: Session = Depends(get_db), _: models.User = Depends(require_admin)):
    """List provider accounts, newest first, with the credential masked (admin only)."""
    newest_first = ProviderAccount.created_at.desc()
    listing = []
    for account in db.query(ProviderAccount).order_by(newest_first).all():
        # Only the last four characters of the stored credential are echoed back.
        visible_tail = account.credential[-4:] if account.credential else ''
        listing.append({
            'id': account.id,
            'provider': account.provider,
            'label': account.label,
            'is_enabled': account.is_enabled,
            'created_at': account.created_at,
            'credential_masked': '***' + visible_tail,
        })
    return listing
@router.post('/admin/providers/accounts', status_code=status.HTTP_201_CREATED)
def create_provider_account(payload: ProviderAccountCreate, db: Session = Depends(get_db), user: models.User = Depends(require_admin)):
    """Store a new, enabled provider account owned by the calling admin.

    Raises HTTP 400 when the normalised provider name is not in
    ``SUPPORTED_PROVIDERS``.
    """
    provider_name = payload.provider.lower().strip()
    if provider_name not in SUPPORTED_PROVIDERS:
        raise HTTPException(status_code=400, detail=f'Unsupported provider: {provider_name}')

    account = ProviderAccount(
        provider=provider_name,
        label=payload.label.strip(),
        credential=payload.credential.strip(),
        is_enabled=True,
        created_by=user.id,
    )
    db.add(account)
    db.commit()
    # Reload so the fields read for the response reflect the committed row.
    db.refresh(account)
    return {
        'id': account.id,
        'provider': account.provider,
        'label': account.label,
        'is_enabled': account.is_enabled,
    }
@router.post('/admin/providers/test')
def test_provider(payload: ProviderTestRequest, _: models.User = Depends(require_admin)):
    """Try a provider credential via ``test_provider_connection`` and report the outcome (admin only)."""
    provider_name = payload.provider.lower().strip()
    credential = payload.credential.strip()
    ok, message = test_provider_connection(provider_name, credential)
    return {'ok': ok, 'message': message}
@router.delete('/admin/providers/accounts/{account_id}', status_code=status.HTTP_204_NO_CONTENT)
def delete_provider_account(account_id: int, db: Session = Depends(get_db), _: models.User = Depends(require_admin)):
    """Delete one provider account by id (admin only); HTTP 404 when it does not exist."""
    by_id = db.query(ProviderAccount).filter(ProviderAccount.id == account_id)
    account = by_id.first()
    if not account:
        raise HTTPException(status_code=404, detail='Provider account not found')

    db.delete(account)
    db.commit()
    return None
@router.get('/admin/servers')
def list_servers(db: Session = Depends(get_db), _: models.User = Depends(require_admin)):
    """Return the server-state view (admin only), built with ``offline_after_minutes=7``."""
    server_states = get_server_states_view(db, offline_after_minutes=7)
    return server_states
@router.post('/admin/servers', status_code=status.HTTP_201_CREATED)
def add_server(payload: MonitoredServerCreate, db: Session = Depends(get_db), user: models.User = Depends(require_admin)):
    """Register a new, enabled monitored server owned by the calling admin.

    Raises HTTP 400 when the stripped identifier is empty or already taken.
    """
    identifier = payload.identifier.strip()
    if not identifier:
        raise HTTPException(status_code=400, detail='identifier required')

    duplicate = db.query(MonitoredServer).filter(MonitoredServer.identifier == identifier).first()
    if duplicate:
        raise HTTPException(status_code=400, detail='identifier already exists')

    server = MonitoredServer(
        identifier=identifier,
        display_name=payload.display_name,
        is_enabled=True,
        created_by=user.id,
    )
    db.add(server)
    db.commit()
    db.refresh(server)
    return {
        'id': server.id,
        'identifier': server.identifier,
        'display_name': server.display_name,
        'is_enabled': server.is_enabled,
    }
@router.post('/admin/servers/{server_id}/challenge', response_model=ChallengeResponse)
def issue_server_challenge(server_id: int, db: Session = Depends(get_db), _: models.User = Depends(require_admin)):
    """Issue a handshake challenge for a server, valid for ten minutes (admin only).

    Raises HTTP 404 when no server has the given id.
    """
    server = db.query(MonitoredServer).filter(MonitoredServer.id == server_id).first()
    if not server:
        raise HTTPException(status_code=404, detail='Server not found')

    token = str(uuid.uuid4())
    valid_until = datetime.now(timezone.utc) + timedelta(minutes=10)
    db.add(ServerChallenge(server_id=server_id, challenge_uuid=token, expires_at=valid_until))
    db.commit()
    return ChallengeResponse(
        identifier=server.identifier,
        challenge_uuid=token,
        expires_at=valid_until.isoformat(),
    )
@router.delete('/admin/servers/{server_id}', status_code=status.HTTP_204_NO_CONTENT)
def delete_server(server_id: int, db: Session = Depends(get_db), _: models.User = Depends(require_admin)):
    """Delete a monitored server and the rows that reference it (admin only).

    Raises HTTP 404 when no server has the given id.
    """
    server = db.query(MonitoredServer).filter(MonitoredServer.id == server_id).first()
    if not server:
        raise HTTPException(status_code=404, detail='Server not found')

    # Dependent rows are removed before the server row itself
    # (presumably so foreign keys do not block the final delete).
    state = db.query(ServerState).filter(ServerState.server_id == server_id).first()
    if state:
        db.delete(state)
    for dependent in (ServerChallenge, ServerHandshakeNonce):
        db.query(dependent).filter(dependent.server_id == server_id).delete()

    db.delete(server)
    db.commit()
    return None
@router.post('/admin/servers/{server_id}/api-key')
def generate_api_key(server_id: int, db: Session = Depends(get_db), _: models.User = Depends(require_admin)):
    """Generate or regenerate API Key for a server (heartbeat v2)"""
    target = db.query(MonitoredServer).filter(MonitoredServer.id == server_id).first()
    if not target:
        raise HTTPException(status_code=404, detail='Server not found')

    # 32 random bytes, URL-safe base64 encoded (about 43 characters).
    # Assigning it overwrites whatever key the server had before.
    target.api_key = api_key = secrets.token_urlsafe(32)
    db.commit()
    return {
        'server_id': target.id,
        'api_key': api_key,
        'message': 'Store this key securely - it will not be shown again',
    }
@router.delete('/admin/servers/{server_id}/api-key', status_code=status.HTTP_204_NO_CONTENT)
def revoke_api_key(server_id: int, db: Session = Depends(get_db), _: models.User = Depends(require_admin)):
    """Revoke API Key for a server"""
    target = db.query(MonitoredServer).filter(MonitoredServer.id == server_id).first()
    if not target:
        raise HTTPException(status_code=404, detail='Server not found')

    # Clearing the column means no X-API-Key value can match this server any more.
    target.api_key = None
    db.commit()
    return None
class ServerHeartbeat(BaseModel):
    """Body of the legacy POST /server/heartbeat call."""

    # Identifier used to look up the MonitoredServer row.
    identifier: str
    openclaw_version: str | None = None
    # Agent descriptors; stored as a JSON string on the server state.
    agents: list[dict] = []
    # Resource readings copied verbatim onto the server state; all optional.
    cpu_pct: float | None = None
    mem_pct: float | None = None
    disk_pct: float | None = None
    swap_pct: float | None = None
@router.post('/server/heartbeat')
def server_heartbeat(payload: ServerHeartbeat, db: Session = Depends(get_db)):
    """Record a heartbeat for the enabled server named by ``payload.identifier``.

    The caller is identified solely by the identifier in the body; this
    handler checks no credential. Raises HTTP 404 when no enabled server
    matches.
    """
    known_server = (
        db.query(MonitoredServer)
        .filter(MonitoredServer.identifier == payload.identifier, MonitoredServer.is_enabled == True)  # noqa: E712
        .first()
    )
    if not known_server:
        raise HTTPException(status_code=404, detail='unknown server identifier')

    # One state row per server: create it on the first heartbeat, update it afterwards.
    state = db.query(ServerState).filter(ServerState.server_id == known_server.id).first()
    if not state:
        state = ServerState(server_id=known_server.id)
        db.add(state)

    state.openclaw_version = payload.openclaw_version
    state.agents_json = json.dumps(payload.agents, ensure_ascii=False)
    for metric in ('cpu_pct', 'mem_pct', 'disk_pct', 'swap_pct'):
        setattr(state, metric, getattr(payload, metric))
    state.last_seen_at = datetime.now(timezone.utc)
    db.commit()
    return {'ok': True, 'server_id': known_server.id, 'last_seen_at': state.last_seen_at}
# Heartbeat v2 with API Key authentication
class TelemetryPayload(BaseModel):
    """Body of POST /server/heartbeat-v2 (the API-key authenticated heartbeat)."""

    identifier: str
    openclaw_version: str | None = None
    # Agent descriptors; stored as a JSON string on the server state.
    agents: list[dict] = []
    # Resource readings; all optional.
    cpu_pct: float | None = None
    mem_pct: float | None = None
    disk_pct: float | None = None
    swap_pct: float | None = None
    # Extended telemetry on top of the legacy heartbeat fields.
    load_avg: list[float] | None = None
    uptime_seconds: int | None = None
@router.post('/server/heartbeat-v2')
def server_heartbeat_v2(
    payload: TelemetryPayload,
    x_api_key: str = Header(..., alias='X-API-Key', description='API Key from /admin/servers/{id}/api-key'),
    db: Session = Depends(get_db)
):
    """Server heartbeat using API Key authentication (no challenge_uuid required)"""
    # The server is resolved from the key alone.
    # NOTE(review): payload.identifier is accepted but never compared with the matched server.
    authenticated = (
        db.query(MonitoredServer)
        .filter(MonitoredServer.api_key == x_api_key, MonitoredServer.is_enabled == True)  # noqa: E712
        .first()
    )
    if not authenticated:
        raise HTTPException(status_code=401, detail='Invalid or missing API Key')

    # One state row per server: create it on the first heartbeat, update it afterwards.
    state = db.query(ServerState).filter(ServerState.server_id == authenticated.id).first()
    if not state:
        state = ServerState(server_id=authenticated.id)
        db.add(state)

    state.openclaw_version = payload.openclaw_version
    state.agents_json = json.dumps(payload.agents, ensure_ascii=False)
    # NOTE(review): load_avg and uptime_seconds are part of the payload but are not persisted here.
    for metric in ('cpu_pct', 'mem_pct', 'disk_pct', 'swap_pct'):
        setattr(state, metric, getattr(payload, metric))
    state.last_seen_at = datetime.now(timezone.utc)
    db.commit()
    return {
        'ok': True,
        'server_id': authenticated.id,
        'identifier': authenticated.identifier,
        'last_seen_at': state.last_seen_at,
    }
def _as_aware_utc(moment):
    """Return ``moment`` as a timezone-aware datetime, reading naive values as UTC.

    Challenge expiries are written from ``datetime.now(timezone.utc)``, so a
    naive value handed back by the database is taken to be UTC wall time.
    """
    if moment.tzinfo is None or moment.utcoffset() is None:
        return moment.replace(tzinfo=timezone.utc)
    return moment


@router.websocket('/server/ws')
async def server_ws(websocket: WebSocket):
    """Authenticate a monitored server over WebSocket, then persist the events it sends.

    Handshake (first JSON frame): either ``encrypted_payload`` (decrypted with
    ``decrypt_payload_b64`` and required to carry a ``ts`` accepted by
    ``ts_within(ts, max_minutes=10)``) or, for backward compatibility, plain
    ``identifier`` / ``challenge_uuid`` / ``nonce`` fields.

    Close codes used by this handler:
      4400  identifier, challenge_uuid or nonce missing
      4401  timestamp rejected, or challenge unknown / already used / expired
      4404  no enabled server with that identifier
      4409  nonce already seen for this server
      1011  unexpected error (now also logged)

    After a successful handshake the socket is registered in ``ACTIVE_WS`` and
    every subsequent frame ``{'event': ..., 'payload': {...}}`` updates the
    server's ``ServerState`` row and its ``last_seen_at``.
    """
    import logging  # function-scope import so this block does not depend on the module's import list

    await websocket.accept()
    # A WebSocket route cannot use the request-scoped get_db dependency the HTTP
    # routes rely on, so the session is opened here and closed in ``finally``.
    db = SessionLocal()
    server_id = None
    try:
        hello = await websocket.receive_json()
        encrypted_payload = (hello.get('encrypted_payload') or '').strip()
        if encrypted_payload:
            data = decrypt_payload_b64(encrypted_payload)
            identifier = (data.get('identifier') or '').strip()
            challenge_uuid = (data.get('challenge_uuid') or '').strip()
            nonce = (data.get('nonce') or '').strip()
            ts = data.get('ts')
            if not ts_within(ts, max_minutes=10):
                await websocket.close(code=4401)
                return
        else:
            # backward compatible mode
            identifier = (hello.get('identifier') or '').strip()
            challenge_uuid = (hello.get('challenge_uuid') or '').strip()
            nonce = (hello.get('nonce') or '').strip()

        if not identifier or not challenge_uuid or not nonce:
            await websocket.close(code=4400)
            return

        server = db.query(MonitoredServer).filter(MonitoredServer.identifier == identifier, MonitoredServer.is_enabled == True).first()  # noqa: E712
        if not server:
            await websocket.close(code=4404)
            return

        ch = db.query(ServerChallenge).filter(ServerChallenge.challenge_uuid == challenge_uuid, ServerChallenge.server_id == server.id).first()
        # Normalise the stored expiry before comparing: some database backends return
        # naive datetimes, and comparing those with an aware "now" raises TypeError.
        if not ch or ch.used_at is not None or _as_aware_utc(ch.expires_at) < datetime.now(timezone.utc):
            await websocket.close(code=4401)
            return

        nonce_used = db.query(ServerHandshakeNonce).filter(ServerHandshakeNonce.server_id == server.id, ServerHandshakeNonce.nonce == nonce).first()
        if nonce_used:
            await websocket.close(code=4409)
            return

        # Burn the nonce and the challenge so neither can be replayed.
        db.add(ServerHandshakeNonce(server_id=server.id, nonce=nonce))
        ch.used_at = datetime.now(timezone.utc)
        db.commit()

        server_id = server.id
        ACTIVE_WS[server.id] = websocket
        await websocket.send_json({'ok': True, 'server_id': server.id, 'message': 'connected'})

        while True:
            msg = await websocket.receive_json()
            event = msg.get('event')
            payload = msg.get('payload') or {}

            st = db.query(ServerState).filter(ServerState.server_id == server.id).first()
            if not st:
                st = ServerState(server_id=server.id)
                db.add(st)

            if event == 'server.hello':
                st.openclaw_version = payload.get('openclaw_version')
                st.agents_json = json.dumps(payload.get('agents') or [], ensure_ascii=False)
            elif event in {'server.metrics', 'agent.status_changed'}:
                # Metrics missing from the frame keep their previously stored value.
                st.cpu_pct = payload.get('cpu_pct', st.cpu_pct)
                st.mem_pct = payload.get('mem_pct', st.mem_pct)
                st.disk_pct = payload.get('disk_pct', st.disk_pct)
                st.swap_pct = payload.get('swap_pct', st.swap_pct)
                if 'agents' in payload:
                    st.agents_json = json.dumps(payload.get('agents') or [], ensure_ascii=False)

            # Any frame, including an unrecognised event, counts as a sign of life.
            st.last_seen_at = datetime.now(timezone.utc)
            db.commit()
    except WebSocketDisconnect:
        pass
    except Exception:
        # Keep the best-effort 1011 close, but leave a trace so failures are diagnosable.
        logging.getLogger(__name__).exception('monitor websocket handler failed (server_id=%s)', server_id)
        try:
            await websocket.close(code=1011)
        except Exception:
            pass
    finally:
        # Only drop the registry entry if it still points at this socket; a newer
        # connection for the same server may have replaced it.
        if server_id and ACTIVE_WS.get(server_id) is websocket:
            ACTIVE_WS.pop(server_id, None)
        db.close()