80 lines
2.5 KiB
Python
80 lines
2.5 KiB
Python
#api/__init__.py
|
|
import os
|
|
from functools import wraps
|
|
from flask import jsonify, session, Blueprint, request, g
|
|
from flask_limiter import Limiter
|
|
from flask_limiter.util import get_remote_address
|
|
from jwt import decode, ExpiredSignatureError, InvalidTokenError
|
|
import importlib
|
|
import requests
|
|
from threading import Lock
|
|
|
|
# Process-wide cache of the realm certificate (PEM text); filled on first use.
_public_key_cache = None
# Guards the first fetch so concurrent callers do not all request the JWKS.
_lock = Lock()


def _select_signing_key(jwks):
    """Return the JWKS entry to use for verifying token signatures.

    A JWKS may list encryption keys next to signing keys, in any order, so
    prefer an entry marked ``use == "sig"`` (with ``alg`` RS256 when the field
    is present) that carries an ``x5c`` chain. Falls back to the first entry
    so key sets without those markers behave as they did before.
    """
    keys = jwks["keys"]
    for key in keys:
        if (
            key.get("use") == "sig"
            and key.get("alg", "RS256") == "RS256"
            and key.get("x5c")
        ):
            return key
    return keys[0]


def keycloak_public_key():
    """Return the Keycloak realm certificate as PEM text, fetching it once.

    The first successful call downloads the realm's JWKS document, wraps the
    signing key's first ``x5c`` entry in certificate armor and caches the
    result in ``_public_key_cache``; later calls return the cached string.

    Returns:
        str: ``-----BEGIN CERTIFICATE-----`` block built from the JWKS entry.

    Raises:
        requests.RequestException: the JWKS request timed out, failed, or
            returned a non-2xx status.
        KeyError, IndexError: the JWKS document has no usable ``keys``/``x5c``.
    """
    global _public_key_cache
    # Fast path: no locking once the cache is populated.
    if _public_key_cache:
        return _public_key_cache
    with _lock:
        # Re-check: another thread may have filled the cache while we waited.
        if _public_key_cache:
            return _public_key_cache

        url = "https://login.hangman-lab.top/realms/Hangman-Lab/protocol/openid-connect/certs"
        # Bounded wait: without a timeout a stalled server would block this
        # worker (and every caller queued on the lock) indefinitely.
        response = requests.get(url, timeout=10)
        # Fail loudly on HTTP errors instead of trying to parse an error page.
        response.raise_for_status()
        jwks = response.json()
        public_key = _select_signing_key(jwks)["x5c"][0]
        # NOTE(review): this passes a certificate PEM (not a bare public key)
        # to the JWT decoder -- confirm the installed PyJWT accepts that form.
        _public_key_cache = f"-----BEGIN CERTIFICATE-----\n{public_key}\n-----END CERTIFICATE-----"
        return _public_key_cache
|
|
|
|
def verify_token(token):
    """Decode ``token`` against the Keycloak realm key.

    The token must be RS256-signed and carry the ``labdev`` audience.

    Returns:
        The decoded claims, or ``None`` when the token is expired or
        otherwise rejected as invalid by the JWT library.
    """
    try:
        claims = decode(token, keycloak_public_key(), algorithms=["RS256"], audience="labdev")
    except (ExpiredSignatureError, InvalidTokenError):
        # Both failure modes are reported to the caller the same way.
        return None
    return claims
|
|
|
|
def require_auth(roles=None):
    """Build a decorator that protects a Flask view with bearer-token auth.

    Args:
        roles: Optional iterable of role names. Every listed role must be
            present in the token's ``roles`` claim. ``None`` or an empty
            iterable means any valid token is accepted.

    Returns:
        A decorator. The wrapped view returns a 401 JSON response when the
        ``Authorization`` header is missing or malformed, or when the token
        fails verification; a 403 JSON response when a required role is
        missing; otherwise it stores the decoded claims on ``g.user`` and
        calls the original view.
    """
    # Snapshot the requirement once; also avoids a shared mutable default.
    required_roles = frozenset(roles or ())

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            auth_header = request.headers.get('Authorization')
            # Expect exactly "<scheme> <token>". A bare "Bearer" used to raise
            # IndexError (HTTP 500), and "Bearerfoo x" used to be accepted.
            parts = auth_header.split() if auth_header else []
            if len(parts) != 2 or parts[0].lower() != "bearer":
                return jsonify({"error": "Unauthorized"}), 401
            decoded = verify_token(parts[1])
            if not decoded:
                return jsonify({"error": "Invalid or expired token"}), 401
            # Treat a missing or null claim as "no roles".
            user_roles = decoded.get("roles") or []
            if not required_roles.issubset(user_roles):
                return jsonify({"error": "Forbidden, permission denied"}), 403
            g.user = decoded
            return func(*args, **kwargs)
        return wrapper
    return decorator
|
|
|
|
|
|
# Module-level rate limiter: keyed by get_remote_address, with a default limit
# of 100 requests per minute.
# NOTE(review): no app is passed here; presumably it is attached elsewhere
# (e.g. limiter.init_app(app)) -- confirm.
limiter = Limiter(default_limits=["100 per minute"], key_func=get_remote_address)
|
|
|
|
|
|
def register_blueprints(app):
    """Import each sibling module of this package and register its blueprints.

    Every ``*.py`` file in this package's directory (except ``__init__.py``)
    is imported as a submodule, and every ``Blueprint`` instance found among
    its attributes is registered on ``app``.

    Args:
        app: The Flask application (anything exposing ``register_blueprint``).
    """
    package_dir = os.path.dirname(__file__)
    # A blueprint imported into a second module shows up there as well; track
    # identities so each object is registered only once.
    registered = set()
    # Sorted so registration order does not depend on directory listing order.
    for filename in sorted(os.listdir(package_dir)):
        if filename == "__init__.py" or not filename.endswith(".py"):
            continue
        # Use this package's own name rather than a hard-coded "api".
        module = importlib.import_module(f"{__name__}.{filename[:-3]}")
        for attr in dir(module):
            candidate = getattr(module, attr)
            if not isinstance(candidate, Blueprint) or id(candidate) in registered:
                continue
            registered.add(id(candidate))
            app.register_blueprint(candidate)