115 lines
3.6 KiB
Python
115 lines
3.6 KiB
Python
"""Keycloak JWT authentication middleware."""
|
||
|
||
import os
|
||
from fastapi import HTTPException, Request
|
||
from jose import jwt, JWTError, jwk
|
||
from jose.utils import base64url_decode
|
||
import httpx
|
||
|
||
# Process-wide store of fetched JWKS documents, keyed by "{kc_host}/{realm}",
# so a realm's signing keys are reused across requests instead of being
# downloaded from Keycloak each time.
_jwks_cache: dict[str, dict] = dict()
|
||
|
||
|
||
async def _get_jwks(kc_host: str, realm: str) -> dict:
    """Return the JSON Web Key Set published by a Keycloak realm.

    The result is memoised in ``_jwks_cache`` under ``"{kc_host}/{realm}"``.
    Callers that need fresh keys (e.g. after key rotation) must pop that
    entry before calling again.

    Args:
        kc_host: Base URL of the Keycloak server. A trailing slash is tolerated.
        realm: Name of the Keycloak realm.

    Returns:
        The decoded JWKS document (a JSON object, expected to hold ``"keys"``).

    Raises:
        HTTPException(502): Keycloak could not be reached, timed out, answered
            with a non-200 status, or returned a body that is not a JSON object.
    """
    # The key format must match the one verify_token evicts on an unknown kid.
    cache_key = f"{kc_host}/{realm}"
    cached = _jwks_cache.get(cache_key)
    if cached is not None:
        return cached

    # rstrip avoids a "//realms/..." path when kc_host ends with "/".
    url = f"{kc_host.rstrip('/')}/realms/{realm}/protocol/openid-connect/certs"
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get(url)
    except httpx.HTTPError as exc:
        # Connection/DNS/timeout failures are upstream errors just like a bad
        # status code; without this they escape as an unhandled 500.
        raise HTTPException(
            status_code=502,
            detail=f"无法获取 Keycloak JWKS: {type(exc).__name__}",
        ) from exc

    if resp.status_code != 200:
        raise HTTPException(
            status_code=502,
            detail=f"无法获取 Keycloak JWKS: HTTP {resp.status_code}",
        )

    try:
        data = resp.json()
    except ValueError:  # includes json.JSONDecodeError / UnicodeDecodeError
        data = None
    if not isinstance(data, dict):
        # Never cache or hand back a body that cannot be a JWKS.
        raise HTTPException(status_code=502, detail="Keycloak JWKS 响应格式无效")

    _jwks_cache[cache_key] = data
    return data
|
||
|
||
|
||
def _find_rsa_key(jwks: dict, token: str) -> dict | None:
    """Select the JWK whose ``kid`` equals the ``kid`` in the token header.

    The header is read without verifying the signature; it only chooses a
    candidate key. ``jwt.get_unverified_header`` raises ``JWTError`` for a
    malformed token, and that error propagates to the caller unchanged.

    Returns:
        The first entry of ``jwks["keys"]`` with a matching ``kid``, or
        ``None`` when no entry matches (or the set has no ``"keys"``).
    """
    wanted_kid = jwt.get_unverified_header(token).get("kid")
    matches = (
        candidate
        for candidate in jwks.get("keys", [])
        if candidate.get("kid") == wanted_kid
    )
    return next(matches, None)
|
||
|
||
|
||
async def verify_token(request: Request, kc_host: str, realm: str) -> dict:
    """Extract and verify the Bearer JWT from the ``Authorization`` header.

    The token signature is verified with RS256 against the realm's JWKS. If
    the token's ``kid`` is not in the cached JWKS, the JWKS is re-fetched once
    to pick up rotated keys. Audience verification is disabled (see below);
    no ``issuer`` is passed to ``jwt.decode``, so ``iss`` is not compared and
    the binding to the realm comes from using that realm's signing keys.

    Args:
        request: Incoming request carrying ``Authorization: Bearer <jwt>``.
        kc_host: Base URL of the Keycloak server.
        realm: Keycloak realm whose keys sign the token.

    Returns:
        The decoded JWT payload.

    Raises:
        HTTPException(401): header missing/not Bearer, token malformed, no
            matching signing key, or signature/claim validation failed.
        HTTPException(502): propagated from fetching the JWKS.
    """
    # RFC 6750 asks for a Bearer challenge on 401 responses.
    challenge = {"WWW-Authenticate": "Bearer"}

    auth_header = request.headers.get("Authorization", "")
    # The auth-scheme name is case-insensitive (RFC 7235), so "bearer" is
    # accepted as well as "Bearer".
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(
            status_code=401,
            detail="缺少 Authorization Bearer token",
            headers=challenge,
        )

    # Reject structurally malformed tokens up front. _find_rsa_key parses the
    # header outside any JWTError handler, so garbage input would otherwise
    # become a 500 (and cost a JWKS fetch plus a refetch).
    try:
        jwt.get_unverified_header(token)
    except JWTError as e:
        raise HTTPException(
            status_code=401, detail=f"JWT 验证失败: {e}", headers=challenge
        ) from e

    jwks = await _get_jwks(kc_host, realm)
    rsa_key = _find_rsa_key(jwks, token)
    if rsa_key is None:
        # Unknown kid: the realm may have rotated keys since the JWKS was
        # cached, so evict it and retry once with a fresh copy.
        # NOTE(review): any client can force this refetch by sending a token
        # with an unknown kid; consider rate-limiting refreshes.
        _jwks_cache.pop(f"{kc_host}/{realm}", None)
        jwks = await _get_jwks(kc_host, realm)
        rsa_key = _find_rsa_key(jwks, token)
        if rsa_key is None:
            raise HTTPException(
                status_code=401, detail="无法匹配 JWT 签名密钥", headers=challenge
            )

    try:
        payload = jwt.decode(
            token,
            rsa_key,
            algorithms=["RS256"],
            # Keycloak audience varies by client, so tokens issued to any
            # client of the realm are accepted.
            # NOTE(review): confirm this cross-client acceptance is intended.
            options={"verify_aud": False},
        )
    except JWTError as e:
        raise HTTPException(
            status_code=401, detail=f"JWT 验证失败: {e}", headers=challenge
        ) from e
    return payload
|
||
|
||
|
||
async def require_auth(request: Request):
    """Authenticate requests to write endpoints.

    * ``ENV_MODE`` unset or ``"dev"``: passthrough, returns ``None``.
    * Keycloak ``host`` not configured: passthrough, returns ``None``.
    * Otherwise the Bearer JWT is verified against the configured realm via
      ``verify_token`` and the decoded payload is returned.

    Raises:
        HTTPException(503): a Keycloak host is configured but no realm is.
        HTTPException(401/502): propagated from ``verify_token``.
    """
    # NOTE(review): an unset ENV_MODE defaults to "dev" and disables auth;
    # confirm production deployments always set ENV_MODE explicitly.
    if os.getenv("ENV_MODE", "dev") == "dev":
        return None

    # Imported inside the function, presumably to avoid an import cycle with
    # app.services — TODO confirm before hoisting to module level.
    from app.services.config_service import ConfigService

    config = ConfigService.load()
    kc = config.get("keycloak", {})
    kc_host = kc.get("host")
    if not kc_host:
        # NOTE(review): fails open — write endpoints are unauthenticated when
        # Keycloak is not configured, even outside dev mode.
        return None  # KC not configured – allow access

    realm = kc.get("realm")
    if not realm:
        # An empty realm would build a malformed JWKS URL; report the
        # misconfiguration the same way require_admin does.
        raise HTTPException(
            status_code=503,
            detail="Keycloak 未配置,无法进行鉴权",
        )

    return await verify_token(request, kc_host, realm)
|
||
|
||
|
||
async def require_admin(request: Request, config: dict):
    """Verify the request carries a valid Keycloak JWT with the ``admin`` realm role.

    Args:
        request: Incoming request whose Bearer token is verified.
        config: Application config; ``config["keycloak"]`` supplies ``host``
            and ``realm``.

    Returns:
        The decoded JWT payload.

    Raises:
        HTTPException(503): Keycloak host or realm is not configured.
        HTTPException(401/502): propagated from ``verify_token``.
        HTTPException(403): the token does not list ``admin`` in
            ``realm_access.roles``.
    """
    # ``or {}`` also covers an explicit ``keycloak: null`` entry, which would
    # otherwise raise AttributeError (500) instead of the intended 503.
    kc = config.get("keycloak") or {}
    kc_host = kc.get("host")
    realm = kc.get("realm")

    if not kc_host or not realm:
        raise HTTPException(
            status_code=503,
            detail="Keycloak 未配置,无法进行鉴权",
        )

    payload = await verify_token(request, kc_host, realm)

    # Read realm roles defensively. A non-dict realm_access (e.g. from a
    # custom protocol mapper) would raise AttributeError, and a string
    # "roles" claim would turn the membership test into a substring match.
    realm_access = payload.get("realm_access")
    roles = realm_access.get("roles") if isinstance(realm_access, dict) else None
    if not isinstance(roles, (list, tuple, set)) or "admin" not in roles:
        raise HTTPException(status_code=403, detail="需要 admin 角色")

    return payload
|