Files
HangmanLab.Backend/api/__init__.py
2024-12-06 16:12:44 +00:00

138 lines
4.0 KiB
Python

#api/__init__.py
import base64
import os
from functools import wraps
from cryptography import x509
from cryptography.hazmat.primitives import serialization
from flask import jsonify, Blueprint, request
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from jwt import decode, ExpiredSignatureError, InvalidTokenError, get_unverified_header
import importlib
import requests
from threading import Lock
import env_provider
# Cache of PEM-encoded public keys, keyed by the JWT "kid" header value.
# Entries are written by get_public_key_for_kid() and never evicted here.
_public_key_cache = {}
# Module-level lock.
# NOTE(review): presumably meant to guard _public_key_cache — confirm usage.
_lock = Lock()
def x5c_to_public_key(x5c):
    """Extract the public key from a base64-encoded DER X.509 certificate.

    Args:
        x5c: Base64 text of a DER certificate (one element of a JWK
            ``x5c`` array).

    Returns:
        The certificate's public key serialized as PEM in
        SubjectPublicKeyInfo format.
    """
    certificate = x509.load_der_x509_certificate(base64.b64decode(x5c))
    return certificate.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
def get_jwks(timeout=10):
    """Fetch the JSON Web Key Set of the configured Keycloak realm.

    Args:
        timeout: Seconds to wait for the identity provider before giving up.
            Without a timeout ``requests`` waits indefinitely, which would
            let a stalled provider pin a worker on every new key lookup.

    Returns:
        The decoded JSON body of the realm's ``certs`` endpoint.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-2xx response.
    """
    url = f"{env_provider.KC_HOST}/realms/{env_provider.KC_REALM}/protocol/openid-connect/certs"
    response = requests.get(url, timeout=timeout)
    # Fail loudly on an HTTP error instead of parsing an error page as a JWKS.
    response.raise_for_status()
    return response.json()
def get_public_key_for_kid(kid):
    """Return the PEM public key for the signing key identified by ``kid``.

    The key is served from ``_public_key_cache`` when present; otherwise the
    JWKS is fetched and the first RSA/RS256 signature key with a matching
    ``kid`` is converted, cached and returned.

    Args:
        kid: Key id taken from the JWT header.

    Returns:
        The PEM-encoded public key, or ``None`` when the JWKS contains no
        suitable key for ``kid``.
    """
    cached = _public_key_cache.get(kid)
    if cached is not None:
        return cached

    with _lock:
        # Re-check under the lock: another thread may have fetched and cached
        # this key while we were waiting, making a second fetch unnecessary.
        cached = _public_key_cache.get(kid)
        if cached is not None:
            return cached

        jwks = get_jwks()
        # Use .get() throughout: a JWKS entry is not guaranteed to carry
        # every field, and one incomplete entry must not abort the lookup.
        for key_data in jwks.get("keys", []):
            if (
                key_data.get("kid") == kid
                and key_data.get("use") == "sig"
                and key_data.get("alg") == "RS256"
                and key_data.get("kty") == "RSA"
                and key_data.get("x5c")
            ):
                pem_public_key = x5c_to_public_key(key_data["x5c"][0])
                _public_key_cache[kid] = pem_public_key
                return pem_public_key
    return None
def verify_token(token):
    """Validate a JWT and return its claims.

    The ``kid`` from the unverified header selects the public key; the token
    is then decoded as RS256 with ``env_provider.KC_CLIENT_ID`` as the
    expected audience.

    Args:
        token: The encoded JWT string.

    Returns:
        The decoded claims, or ``None`` when the header has no ``kid``, no
        public key is found for it, or the token is expired or invalid.
    """
    try:
        key_id = get_unverified_header(token).get("kid")
        if not key_id:
            return None

        signing_key = get_public_key_for_kid(key_id)
        if not signing_key:
            return None

        return decode(
            token,
            signing_key,
            algorithms=["RS256"],
            audience=env_provider.KC_CLIENT_ID,
        )
    except (ExpiredSignatureError, InvalidTokenError) as err:
        # Both failure kinds are reported the same way: print and reject.
        print(err)
        return None
def require_auth(roles=()):
    """Build a decorator that protects a Flask view with bearer-token auth.

    Args:
        roles: Role names of which the caller must hold at least one. The
            caller's roles are read from the token's
            ``resource_access[KC_CLIENT_ID]["roles"]`` claim. An empty value
            (the default) means any valid token is accepted.

    Returns:
        A decorator. The wrapped view answers ``OPTIONS`` with an empty 200,
        returns 401 for a missing or malformed ``Authorization`` header or an
        invalid token, returns 403 when no required role is held, and
        otherwise calls the original view.
    """
    # Computed once per decorated view rather than on every request.
    required_roles = set(roles) if roles else set()

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if request.method == "OPTIONS":
                return '', 200
            auth_header = request.headers.get('Authorization')
            parts = auth_header.split() if auth_header else []
            # Require exactly the "Bearer" scheme followed by a token. A bare
            # "Bearer" header used to raise IndexError and surface as a 500.
            if len(parts) < 2 or parts[0] != 'Bearer':
                return jsonify({"error": "Unauthorized"}), 401
            decoded = verify_token(parts[1])
            if not decoded:
                return jsonify({"error": "Invalid or expired token"}), 401
            user_roles = decoded.get("resource_access", {}).get(env_provider.KC_CLIENT_ID, {}).get("roles", [])
            if required_roles and not (required_roles & set(user_roles)):
                print("auth failed")
                return jsonify({"error": "Forbidden, permission denied"}), 403
            print("auth success")
            return func(*args, **kwargs)
        return wrapper
    return decorator
# Per-endpoint limits, keyed "<rule pattern> : <HTTP method>".
# Filled by init_rate_limits().
rate_limits = {}
# Limit assigned to every endpoint by init_rate_limits() and used as the
# fallback for keys missing from rate_limits.
default_rate_limit = "60 per minute"


def init_rate_limits(app):
    """Rebuild ``rate_limits`` with one entry per (rule, method) of ``app``.

    Every entry is set to ``default_rate_limit``. Any previous content of
    ``rate_limits`` is replaced.

    Args:
        app: Object exposing ``url_map.iter_rules()``; each rule provides
            ``rule`` (the pattern string) and ``methods``.
    """
    global rate_limits
    rate_limits = {
        f"{rule.rule} : {method}": default_rate_limit
        for rule in app.url_map.iter_rules()
        # A rule declared without an explicit method list carries
        # methods=None; treat it as contributing no entries, not a TypeError.
        for method in (rule.methods or ())
    }
def get_rate_limit():
    """Return the rate-limit string for the current request.

    ``rate_limits`` is keyed by rule pattern (e.g. ``/items/<int:id>``), so
    the matched rule's pattern is tried first. A lookup by the concrete
    request path alone would never match a parameterised route. The concrete
    path is still tried afterwards so that any path-keyed entry keeps working.

    Returns:
        The configured limit, or ``default_rate_limit`` when no entry
        matches.
    """
    method = request.method
    matched_rule = request.url_rule  # None when no route matched the request
    if matched_rule is not None:
        limit = rate_limits.get(f"{matched_rule.rule} : {method}")
        if limit is not None:
            return limit
    return rate_limits.get(f"{request.path} : {method}", default_rate_limit)
# Shared flask-limiter instance. Requests are bucketed by the value returned
# from get_remote_address, with a default limit of 100 requests per minute.
# No app is passed here. NOTE(review): presumably bound to the Flask app
# elsewhere (e.g. via init_app) — confirm.
# NOTE(review): this default differs from default_rate_limit ("60 per minute")
# defined in this module — confirm which value is intended.
limiter = Limiter(
key_func=get_remote_address,
default_limits=["100 per minute"]
)
def register_blueprints(app):
    """Import every module in this package and register its Blueprints.

    Each ``*.py`` file next to this one (except ``__init__.py``) is imported
    as ``api.<module>``, and every module attribute that is a ``Blueprint``
    is registered on ``app``.

    Args:
        app: The Flask application to register the blueprints on.
    """
    current_dir = os.path.dirname(__file__)
    # The same Blueprint object can be reachable under several names (an
    # alias, or one module importing another's blueprint). Recent Flask
    # versions reject a second registration under the same name, so each
    # object is registered only once.
    seen_ids = set()
    # os.listdir() order is unspecified; sort for a stable registration order.
    for filename in sorted(os.listdir(current_dir)):
        if filename == "__init__.py" or not filename.endswith(".py"):
            continue
        module_name = filename[:-3]
        module = importlib.import_module(f"api.{module_name}")
        for attr in dir(module):
            bp = getattr(module, attr)
            if isinstance(bp, Blueprint) and id(bp) not in seen_ids:
                seen_ids.add(id(bp))
                app.register_blueprint(bp)