Files
HangmanLab.Backend/api/__init__.py
2024-12-04 14:06:30 +00:00

80 lines
2.5 KiB
Python

#api/__init__.py
import os
from functools import wraps
from flask import jsonify, session, Blueprint, request, g
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from jwt import decode, ExpiredSignatureError, InvalidTokenError
import importlib
import requests
from threading import Lock
# PEM text of the Keycloak signing certificate; None until the first fetch.
_public_key_cache = None
# Guards the first fetch so concurrent callers do not each hit the endpoint.
_lock = Lock()


def _select_signing_key(keys):
    """Return the JWKS entry to use for RS256 signature verification.

    Prefers the first entry marked for signature use (``use == "sig"``) that
    is not declared for a different algorithm and carries an ``x5c`` chain.
    Falls back to the first entry so that a key set without those markers
    is handled exactly as before.
    """
    for key in keys:
        if (
            key.get("use") == "sig"
            and key.get("alg", "RS256") == "RS256"
            and key.get("x5c")
        ):
            return key
    return keys[0]


def keycloak_public_key():
    """Return the realm's signing certificate as PEM text, fetching it once.

    The first call downloads the realm's JWKS document, wraps the leading
    ``x5c`` entry of the signing key in certificate PEM markers and stores
    the result in a module-level cache; later calls return the cached text.

    Returns:
        The PEM-wrapped certificate string.

    Raises:
        requests.RequestException: the JWKS endpoint is unreachable, times
            out, or answers with an HTTP error status.
        KeyError, IndexError: the JWKS document lacks the expected fields.
    """
    # NOTE(review): the cached value is a CERTIFICATE PEM, not a PUBLIC KEY
    # PEM - confirm the installed PyJWT accepts it as an RS256 key.
    global _public_key_cache
    if _public_key_cache:
        return _public_key_cache
    with _lock:
        # Re-check under the lock: another thread may have filled the cache
        # while this one was waiting.
        if _public_key_cache:
            return _public_key_cache
        url = "https://login.hangman-lab.top/realms/Hangman-Lab/protocol/openid-connect/certs"
        # Bound the wait so a stalled identity server cannot block every
        # caller queued on the lock indefinitely.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        jwks = response.json()
        public_key = _select_signing_key(jwks["keys"])["x5c"][0]
        _public_key_cache = f"-----BEGIN CERTIFICATE-----\n{public_key}\n-----END CERTIFICATE-----"
    return _public_key_cache
def verify_token(token):
    """Decode a JWT with the Keycloak realm key and return its claims.

    The token is decoded with the RS256 algorithm and the ``labdev``
    audience.

    Returns:
        The decoded claims, or ``None`` when decoding raises
        ``ExpiredSignatureError`` or ``InvalidTokenError``.
    """
    decode_options = {"algorithms": ["RS256"], "audience": "labdev"}
    try:
        return decode(token, keycloak_public_key(), **decode_options)
    except (ExpiredSignatureError, InvalidTokenError):
        # Both failure kinds are reported to the caller the same way.
        return None
def _bearer_token(auth_header):
    """Return the token from a ``Bearer <token>`` header value, or ``None``.

    ``None`` is returned when the header is missing or empty, uses another
    scheme, or is not exactly the scheme followed by a single token.
    """
    if not auth_header:
        return None
    parts = auth_header.split()
    if len(parts) != 2 or parts[0] != "Bearer":
        return None
    return parts[1]


def require_auth(roles=()):
    """Build a decorator that protects a view with bearer-token authentication.

    Args:
        roles: Iterable of role names the caller must all hold. Empty or
            ``None`` means any valid token is accepted.

    Returns:
        A decorator. The wrapped view answers 401 when the ``Authorization``
        header is missing or malformed or the token fails verification, and
        403 when a required role is absent. Otherwise it stores the decoded
        claims on ``g.user`` and calls the original view.
    """
    # Built once at decoration time; also avoids a mutable default argument.
    required_roles = frozenset(roles or ())

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            token = _bearer_token(request.headers.get('Authorization'))
            if token is None:
                return jsonify({"error": "Unauthorized"}), 401
            decoded = verify_token(token)
            if not decoded:
                return jsonify({"error": "Invalid or expired token"}), 401
            # NOTE(review): roles are read from a top-level "roles" claim -
            # confirm the token issuer actually maps roles there.
            user_roles = decoded.get("roles", [])
            if required_roles and not required_roles.issubset(user_roles):
                return jsonify({"error": "Forbidden, permission denied"}), 403
            g.user = decoded
            return func(*args, **kwargs)
        return wrapper
    return decorator
# Module-level rate limiter: requests are keyed by get_remote_address and the
# default limit is "100 per minute". No app is passed here - presumably it is
# attached to the Flask app elsewhere; confirm.
limiter = Limiter(default_limits=["100 per minute"], key_func=get_remote_address)
def register_blueprints(app):
    """Import each sibling module of this package and register its blueprints.

    Every ``*.py`` file next to this one (except ``__init__.py``) is imported
    as ``api.<name>``, and each ``Blueprint`` instance found among the
    module's attributes is registered on ``app``.

    A blueprint object reachable more than once - under two names, or through
    a cross-import between modules - is registered only once, because recent
    Flask versions reject a second registration under the same name.

    Args:
        app: The application object; only ``register_blueprint`` is called
            on it.
    """
    current_dir = os.path.dirname(__file__)
    seen = set()  # ids of blueprint objects already registered
    # os.listdir gives no ordering guarantee; sort for a reproducible order.
    for filename in sorted(os.listdir(current_dir)):
        if filename == "__init__.py" or not filename.endswith(".py"):
            continue
        module = importlib.import_module(f"api.{filename[:-3]}")
        for attr in dir(module):
            candidate = getattr(module, attr)
            if not isinstance(candidate, Blueprint) or id(candidate) in seen:
                continue
            seen.add(id(candidate))
            app.register_blueprint(candidate)