"""Keycloak JWT authentication middleware.""" import os from fastapi import HTTPException, Request from jose import jwt, JWTError, jwk from jose.utils import base64url_decode import httpx # Cache JWKS per (host, realm) to avoid fetching on every request _jwks_cache: dict[str, dict] = {} async def _get_jwks(kc_host: str, realm: str) -> dict: cache_key = f"{kc_host}/{realm}" if cache_key in _jwks_cache: return _jwks_cache[cache_key] url = f"{kc_host}/realms/{realm}/protocol/openid-connect/certs" async with httpx.AsyncClient(timeout=10) as client: resp = await client.get(url) if resp.status_code != 200: raise HTTPException( status_code=502, detail=f"无法获取 Keycloak JWKS: HTTP {resp.status_code}", ) data = resp.json() _jwks_cache[cache_key] = data return data def _find_rsa_key(jwks: dict, token: str) -> dict | None: """Find the matching RSA key from JWKS for the token's kid.""" unverified_header = jwt.get_unverified_header(token) kid = unverified_header.get("kid") for key in jwks.get("keys", []): if key.get("kid") == kid: return key return None async def verify_token(request: Request, kc_host: str, realm: str) -> dict: """Extract and verify the Bearer JWT from the Authorization header. Returns the decoded payload on success. Raises HTTPException(401) on missing/invalid token. """ auth_header = request.headers.get("Authorization", "") if not auth_header.startswith("Bearer "): raise HTTPException(status_code=401, detail="缺少 Authorization Bearer token") token = auth_header[7:] jwks = await _get_jwks(kc_host, realm) rsa_key = _find_rsa_key(jwks, token) if rsa_key is None: # Clear cache in case keys rotated _jwks_cache.pop(f"{kc_host}/{realm}", None) jwks = await _get_jwks(kc_host, realm) rsa_key = _find_rsa_key(jwks, token) if rsa_key is None: raise HTTPException(status_code=401, detail="无法匹配 JWT 签名密钥") try: payload = jwt.decode( token, rsa_key, algorithms=["RS256"], options={"verify_aud": False}, # Keycloak audience varies by client ) return payload except JWTError as e: raise HTTPException(status_code=401, detail=f"JWT 验证失败: {e}") async def require_auth(request: Request): """Verify Bearer JWT for write endpoints. Dev mode: passthrough (no auth required). Prod mode: validates JWT via Keycloak JWKS. """ if os.getenv("ENV_MODE", "dev") == "dev": return None from app.services.config_service import ConfigService config = ConfigService.load() kc = config.get("keycloak", {}) if not kc.get("host"): return None # KC not configured – allow access return await verify_token(request, kc["host"], kc.get("realm", "")) async def require_admin(request: Request, config: dict): """Verify the request carries a valid Keycloak JWT with admin role. Raises HTTPException(401/403) on failure. """ kc = config.get("keycloak", {}) kc_host = kc.get("host") realm = kc.get("realm") if not kc_host or not realm: raise HTTPException( status_code=503, detail="Keycloak 未配置,无法进行鉴权", ) payload = await verify_token(request, kc_host, realm) roles = payload.get("realm_access", {}).get("roles", []) if "admin" not in roles: raise HTTPException(status_code=403, detail="需要 admin 角色") return payload