init
This commit is contained in:
0
middleware/__init__.py
Normal file
0
middleware/__init__.py
Normal file
114
middleware/auth.py
Normal file
114
middleware/auth.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Keycloak JWT authentication middleware."""
|
||||
|
||||
import os
|
||||
from fastapi import HTTPException, Request
|
||||
from jose import jwt, JWTError, jwk
|
||||
from jose.utils import base64url_decode
|
||||
import httpx
|
||||
|
||||
# Cache JWKS per (host, realm) to avoid fetching on every request
|
||||
_jwks_cache: dict[str, dict] = {}
|
||||
|
||||
|
||||
async def _get_jwks(kc_host: str, realm: str) -> dict:
|
||||
cache_key = f"{kc_host}/{realm}"
|
||||
if cache_key in _jwks_cache:
|
||||
return _jwks_cache[cache_key]
|
||||
|
||||
url = f"{kc_host}/realms/{realm}/protocol/openid-connect/certs"
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
resp = await client.get(url)
|
||||
if resp.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"无法获取 Keycloak JWKS: HTTP {resp.status_code}",
|
||||
)
|
||||
data = resp.json()
|
||||
_jwks_cache[cache_key] = data
|
||||
return data
|
||||
|
||||
|
||||
def _find_rsa_key(jwks: dict, token: str) -> dict | None:
|
||||
"""Find the matching RSA key from JWKS for the token's kid."""
|
||||
unverified_header = jwt.get_unverified_header(token)
|
||||
kid = unverified_header.get("kid")
|
||||
for key in jwks.get("keys", []):
|
||||
if key.get("kid") == kid:
|
||||
return key
|
||||
return None
|
||||
|
||||
|
||||
async def verify_token(request: Request, kc_host: str, realm: str) -> dict:
|
||||
"""Extract and verify the Bearer JWT from the Authorization header.
|
||||
|
||||
Returns the decoded payload on success.
|
||||
Raises HTTPException(401) on missing/invalid token.
|
||||
"""
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="缺少 Authorization Bearer token")
|
||||
|
||||
token = auth_header[7:]
|
||||
|
||||
jwks = await _get_jwks(kc_host, realm)
|
||||
rsa_key = _find_rsa_key(jwks, token)
|
||||
if rsa_key is None:
|
||||
# Clear cache in case keys rotated
|
||||
_jwks_cache.pop(f"{kc_host}/{realm}", None)
|
||||
jwks = await _get_jwks(kc_host, realm)
|
||||
rsa_key = _find_rsa_key(jwks, token)
|
||||
if rsa_key is None:
|
||||
raise HTTPException(status_code=401, detail="无法匹配 JWT 签名密钥")
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
rsa_key,
|
||||
algorithms=["RS256"],
|
||||
options={"verify_aud": False}, # Keycloak audience varies by client
|
||||
)
|
||||
return payload
|
||||
except JWTError as e:
|
||||
raise HTTPException(status_code=401, detail=f"JWT 验证失败: {e}")
|
||||
|
||||
|
||||
async def require_auth(request: Request):
|
||||
"""Verify Bearer JWT for write endpoints.
|
||||
|
||||
Dev mode: passthrough (no auth required).
|
||||
Prod mode: validates JWT via Keycloak JWKS.
|
||||
"""
|
||||
if os.getenv("ENV_MODE", "dev") == "dev":
|
||||
return None
|
||||
|
||||
from app.services.config_service import ConfigService
|
||||
config = ConfigService.load()
|
||||
kc = config.get("keycloak", {})
|
||||
if not kc.get("host"):
|
||||
return None # KC not configured – allow access
|
||||
|
||||
return await verify_token(request, kc["host"], kc.get("realm", ""))
|
||||
|
||||
|
||||
async def require_admin(request: Request, config: dict):
|
||||
"""Verify the request carries a valid Keycloak JWT with admin role.
|
||||
|
||||
Raises HTTPException(401/403) on failure.
|
||||
"""
|
||||
kc = config.get("keycloak", {})
|
||||
kc_host = kc.get("host")
|
||||
realm = kc.get("realm")
|
||||
|
||||
if not kc_host or not realm:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Keycloak 未配置,无法进行鉴权",
|
||||
)
|
||||
|
||||
payload = await verify_token(request, kc_host, realm)
|
||||
|
||||
roles = payload.get("realm_access", {}).get("roles", [])
|
||||
if "admin" not in roles:
|
||||
raise HTTPException(status_code=403, detail="需要 admin 角色")
|
||||
|
||||
return payload
|
||||
32
middleware/config_guard.py
Normal file
32
middleware/config_guard.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from services.config_service import ConfigService
|
||||
|
||||
# Paths that are always accessible, even when DB is not configured
|
||||
_ALLOWED_PREFIXES = ("/api/setup", "/docs", "/openapi", "/redoc")
|
||||
|
||||
|
||||
class ConfigGuardMiddleware(BaseHTTPMiddleware):
|
||||
"""Return 503 for all business routes when the database is not configured."""
|
||||
|
||||
async def dispatch(self, request, call_next):
|
||||
path = request.url.path
|
||||
|
||||
# Always allow: setup routes, root, docs, OPTIONS (CORS preflight)
|
||||
if path == "/" or request.method == "OPTIONS":
|
||||
return await call_next(request)
|
||||
for prefix in _ALLOWED_PREFIXES:
|
||||
if path.startswith(prefix):
|
||||
return await call_next(request)
|
||||
|
||||
if not ConfigService.is_db_configured():
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={
|
||||
"error_code": "SERVICE_NOT_CONFIGURED",
|
||||
"detail": "数据库未配置,请先完成系统初始化",
|
||||
},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
Reference in New Issue
Block a user