import base64
import hashlib
import importlib
import json
import os
from functools import wraps
from threading import Lock

import requests
from cryptography import x509
from cryptography.hazmat.primitives import serialization
from flask import jsonify, Blueprint, request, make_response
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from jwt import decode, ExpiredSignatureError, InvalidTokenError, get_unverified_header

import env_provider

# kid -> PEM-encoded public key (bytes), filled lazily from the realm's JWKS.
_public_key_cache = {}
# Guards _public_key_cache. It is held across the JWKS fetch, so concurrent
# cache misses for the same kid trigger a single request to Keycloak.
_lock = Lock()


def x5c_to_public_key(x5c):
    """Return the PEM (SubjectPublicKeyInfo) public key of a base64 DER certificate.

    :param x5c: one entry of a JWK's ``x5c`` chain (base64-encoded DER cert).
    :returns: PEM-encoded public key bytes, passed as the key to ``jwt.decode``.
    """
    cert_der = base64.b64decode(x5c)
    cert = x509.load_der_x509_certificate(cert_der)
    return cert.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )


def get_jwks(timeout=10):
    """Fetch the Keycloak realm's JSON Web Key Set.

    :param timeout: seconds to wait for Keycloak. Without a timeout a hung IdP
        would block this thread, and (because callers hold ``_lock``) every
        other token verification, indefinitely.
    :returns: the parsed JWKS document.
    :raises requests.RequestException: on connection errors, timeouts or a
        non-2xx response.
    :raises ValueError: if the response body is not valid JSON.
    """
    url = f"{env_provider.KC_HOST}/realms/{env_provider.KC_REALM}/protocol/openid-connect/certs"
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of treating an error page as a key set.
    response.raise_for_status()
    return response.json()


def get_public_key_for_kid(kid):
    """Return the PEM public key for signing key ``kid``, or ``None`` if unknown.

    Only RSA keys marked ``use=sig`` / ``alg=RS256`` that carry an ``x5c``
    certificate are accepted. Hits are cached in ``_public_key_cache``.

    NOTE(review): misses are not cached, so every token with an unknown kid
    triggers a fresh JWKS fetch. This is intentional for key rotation, but
    consider rate-limiting if tokens are attacker-controlled.

    :raises requests.RequestException: propagated from ``get_jwks``.
    """
    with _lock:
        cached = _public_key_cache.get(kid)
        if cached is not None:
            return cached

        for key_data in get_jwks().get("keys", []):
            # .get() so that JWKs lacking optional members (e.g. "use" or
            # "x5c") are skipped instead of raising KeyError.
            if (
                key_data.get("kid") == kid
                and key_data.get("use") == "sig"
                and key_data.get("alg") == "RS256"
                and key_data.get("kty") == "RSA"
                and key_data.get("x5c")
            ):
                pem_public_key = x5c_to_public_key(key_data["x5c"][0])
                _public_key_cache[kid] = pem_public_key
                return pem_public_key
        return None


def verify_token(token):
    """Validate an RS256 Keycloak access token and return its claims.

    The verification key is chosen by the token header's ``kid``. Signature,
    expiry and audience (``env_provider.KC_CLIENT_ID``) are checked by
    ``jwt.decode``.

    :returns: the decoded claims dict, or ``None`` when the token has no
        ``kid``, references an unknown key, or fails PyJWT validation.
    :raises requests.RequestException: propagated from the JWKS fetch; an IdP
        outage is a server error, not a client authentication failure.
    """
    try:
        kid = get_unverified_header(token).get("kid")
        if not kid:
            return None

        public_key = get_public_key_for_kid(kid)
        if not public_key:
            return None

        return decode(
            token,
            public_key,
            algorithms=["RS256"],
            audience=env_provider.KC_CLIENT_ID,
        )
    except (ExpiredSignatureError, InvalidTokenError) as e:
        # ExpiredSignatureError subclasses InvalidTokenError; both mean
        # "not authenticated" and are reported the same way.
        print(e)
        return None


def require_auth(roles=None):
    """Decorator factory requiring a valid Bearer token and, optionally, a role.

    :param roles: client roles (read from
        ``resource_access[KC_CLIENT_ID].roles`` in the token) of which the
        caller must hold at least one. ``None``/empty accepts any
        authenticated caller.

    Responses: 401 for a missing, malformed, invalid or expired token; 403
    when none of ``roles`` is held. ``OPTIONS`` requests bypass the check.
    """
    required_roles = set(roles or ())

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Let CORS preflight requests through without credentials.
            if request.method == "OPTIONS":
                return '', 200

            auth_header = request.headers.get('Authorization', '')
            scheme, _, token = auth_header.partition(" ")
            token = token.strip()
            # partition() never raises, so a bare "Bearer" header (or a
            # different scheme) is a clean 401 instead of an IndexError/500.
            if scheme.lower() != "bearer" or not token:
                return jsonify({"error": "Unauthorized"}), 401

            decoded = verify_token(token)
            if not decoded:
                return jsonify({"error": "Invalid or expired token"}), 401

            user_roles = (
                decoded.get("resource_access", {})
                .get(env_provider.KC_CLIENT_ID, {})
                .get("roles", [])
            )
            if required_roles and required_roles.isdisjoint(user_roles):
                print("auth failed")
                return jsonify({"error": "Forbidden, permission denied"}), 403

            print("auth success")
            return func(*args, **kwargs)

        return wrapper

    return decorator


# "<url rule> : <HTTP method>" -> limit string, populated by init_rate_limits().
rate_limits = {}
default_rate_limit = "60 per minute"


def init_rate_limits(app):
    """Seed ``rate_limits`` with the default limit for every (rule, method) pair of ``app``."""
    global rate_limits
    rate_limits = {
        f"{rule.rule} : {method}": default_rate_limit
        for rule in app.url_map.iter_rules()
        for method in rule.methods
    }


def get_rate_limit():
    """Return the limit string configured for the current request's route.

    ``rate_limits`` is keyed by URL *rule* (e.g. ``/users/<int:id>``), so the
    matched rule is used for the lookup; ``request.path`` (``/users/5``) would
    never match a rule containing variables. Falls back to ``request.path``
    when no rule matched, and to ``default_rate_limit`` when there is no entry.
    """
    rule = request.url_rule
    path = rule.rule if rule is not None else request.path
    return rate_limits.get(f"{path} : {request.method}", default_rate_limit)


limiter = Limiter(
    key_func=get_remote_address,
    default_limits=["100 per minute"]
)


def register_blueprints(app):
    """Register every ``Blueprint`` found in the modules of the ``api`` package.

    Each ``*.py`` file next to this one (except ``__init__.py``) is imported
    as ``api.<name>``, and every module attribute that is a ``Blueprint`` is
    registered. A blueprint visible from several modules (e.g. imported by a
    sibling) is registered only once, since recent Flask versions raise when
    the same blueprint name is registered twice.
    """
    current_dir = os.path.dirname(__file__)
    registered = set()
    # sorted() makes registration order independent of filesystem order.
    for filename in sorted(os.listdir(current_dir)):
        if filename == "__init__.py" or not filename.endswith(".py"):
            continue
        module_name = filename[:-3]
        module = importlib.import_module(f"api.{module_name}")
        for attr in dir(module):
            bp = getattr(module, attr)
            if isinstance(bp, Blueprint) and id(bp) not in registered:
                registered.add(id(bp))
                app.register_blueprint(bp)


def generate_etag(data):
    """Return a stable hex digest of JSON-serializable ``data`` for use as an ETag.

    Keys are sorted so logically equal dicts produce the same tag. MD5 is used
    only as a change detector, not for security.
    """
    serialized = json.dumps(data, sort_keys=True).encode('utf-8')
    return hashlib.md5(serialized).hexdigest()


def etag_response(f):
    """Decorator adding ETag / If-None-Match support to JSON view functions.

    Applies only when the view returns a ``(body, status[, headers])`` tuple
    with status 200 or 201 and a ``dict``/``list`` body; anything else is
    returned unchanged. On a GET/HEAD whose ``If-None-Match`` matches the
    body's tag, a bodyless ``304 Not Modified`` is returned so clients keep
    their cached copy. Otherwise the full response gets an ``ETag`` header.
    """
    @wraps(f)
    def decorator(*args, **kwargs):
        response = f(*args, **kwargs)
        # Response objects, bare dicts, etc. cannot be inspected here.
        if not isinstance(response, tuple) or len(response) < 2:
            return response

        body, status = response[0], response[1]
        if status not in (200, 201) or not isinstance(body, (dict, list)):
            return response

        etag = generate_etag(body)
        # If-None-Match uses weak comparison (RFC 9110 §13.1.2); werkzeug
        # parses quoted, unquoted and "*" forms.
        if request.method in ("GET", "HEAD") and request.if_none_match.contains_weak(etag):
            not_modified = make_response("", 304)
            not_modified.set_etag(etag)
            return not_modified

        # Unpack the whole tuple so an optional headers element is preserved.
        resp = make_response(*response)
        resp.set_etag(etag)  # emits a properly quoted ETag header
        return resp

    return decorator