"""Authentication, rate-limit, blueprint-registration and ETag helpers.

Provides:
* Keycloak JWT verification backed by a per-``kid`` public-key cache.
* ``require_auth``: a decorator accepting either a Bearer token or an
  ``X-API-Key`` header.
* Per-route rate-limit lookup and the shared ``Limiter`` instance.
* Automatic blueprint discovery, ETag support and API-key helpers.
"""
import base64
import hashlib
import importlib
import json
import logging
import os
import pkgutil
import secrets
import string
from datetime import datetime, UTC
from functools import wraps
from threading import Lock

import requests
from cryptography import x509
from cryptography.hazmat.primitives import serialization
from flask import jsonify, Blueprint, request, make_response
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from jwt import decode, ExpiredSignatureError, InvalidTokenError, get_unverified_header

from db.models.APIKey import APIKey
from db import get_db
import env_provider

logger = logging.getLogger(__name__)

# PEM-encoded public keys keyed by JWT ``kid``; guarded by ``_lock``.
_public_key_cache = {}
_lock = Lock()

# Upper bound for the JWKS HTTP call so a stalled identity provider cannot
# block request threads forever (the call is made while holding ``_lock``).
_JWKS_TIMEOUT_SECONDS = 10


def x5c_to_public_key(x5c):
    """Return the PEM-encoded public key of a base64 DER certificate.

    Args:
        x5c: Base64-encoded DER X.509 certificate (one entry of a JWK
            ``x5c`` array).

    Returns:
        The certificate's public key as PEM ``SubjectPublicKeyInfo`` bytes.
    """
    cert_der = base64.b64decode(x5c)
    cert = x509.load_der_x509_certificate(cert_der)
    return cert.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )


def get_jwks():
    """Fetch and return the realm's JWKS document as parsed JSON.

    Raises:
        requests.RequestException: on network failure, timeout or a
            non-2xx HTTP status.
    """
    url = f"{env_provider.KC_HOST}/realms/{env_provider.KC_REALM}/protocol/openid-connect/certs"
    response = requests.get(url, timeout=_JWKS_TIMEOUT_SECONDS)
    # Fail loudly on an error status instead of trying to parse an error body.
    response.raise_for_status()
    return response.json()


def get_public_key_for_kid(kid):
    """Return the PEM public key for ``kid``, or ``None`` if no key matches.

    A cached key is returned when available; otherwise the JWKS is fetched
    and the first RS256/RSA signing key with a matching ``kid`` is cached
    and returned.
    """
    with _lock:
        cached = _public_key_cache.get(kid)
        if cached is not None:
            return cached

        jwks = get_jwks()
        for key_data in jwks.get("keys", []):
            # ``use``, ``alg`` and ``x5c`` are optional JWK members, so read
            # them with .get() rather than indexing.
            is_match = (
                key_data.get("kid") == kid
                and key_data.get("use") == "sig"
                and key_data.get("alg") == "RS256"
                and key_data.get("kty") == "RSA"
            )
            if not is_match:
                continue
            x5c_chain = key_data.get("x5c")
            if not x5c_chain:
                continue
            pem_public_key = x5c_to_public_key(x5c_chain[0])
            _public_key_cache[kid] = pem_public_key
            return pem_public_key
    return None


def verify_token(token):
    """Validate a JWT and return its decoded claims, or ``None`` if invalid.

    The token must carry a ``kid`` header that resolves to a known public
    key, be signed with RS256, and have ``env_provider.KC_CLIENT_ID`` as
    its audience.
    """
    try:
        header = get_unverified_header(token)
        kid = header.get("kid")
        if not kid:
            return None

        public_key = get_public_key_for_kid(kid)
        if not public_key:
            return None

        return decode(
            token,
            public_key,
            algorithms=["RS256"],
            audience=env_provider.KC_CLIENT_ID,
        )
    except ExpiredSignatureError as exc:
        logger.warning("Token expired: %s", exc)
        return None
    except InvalidTokenError as exc:
        logger.warning("Invalid token: %s", exc)
        return None


def _extract_bearer_token(auth_header):
    """Return the token from a ``Bearer <token>`` header value, else ``None``.

    Returns ``None`` for a missing header, a different scheme, or a header
    with no token after the scheme.
    """
    if not auth_header:
        return None
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        return None
    return token


def _client_roles(decoded):
    """Return the role list granted for this client in decoded claims."""
    return (
        decoded.get("resource_access", {})
        .get(env_provider.KC_CLIENT_ID, {})
        .get("roles", [])
    )


def is_user_admin():
    """Return ``True`` if the request carries a valid token with role ``admin``."""
    token = _extract_bearer_token(request.headers.get('Authorization'))
    if token is None:
        return False
    decoded = verify_token(token)
    if not decoded:
        return False
    return 'admin' in _client_roles(decoded)


def require_auth(roles=None):
    """Decorator factory enforcing Bearer-token or API-key authentication.

    Args:
        roles: Optional iterable of role names. When non-empty, the caller
            must hold at least one of them; when empty or ``None`` any
            authenticated caller is accepted.

    The wrapped view returns a JSON error with status 401 or 403 when
    authentication or authorisation fails. ``OPTIONS`` requests bypass
    authentication.
    """
    required_roles = set(roles or ())

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if request.method == "OPTIONS":
                return '', 200

            auth_header = request.headers.get('Authorization')
            api_key_header = request.headers.get('X-API-Key')

            if auth_header and api_key_header:
                return jsonify({"error": "Cannot use both Bearer token and API Key authentication"}), 403

            if api_key_header:
                api_key = get_api_key(api_key_header)
                if not api_key:
                    return jsonify({"error": "Invalid API key"}), 401

                # NOTE(review): assumes ``expire`` is always set; a NULL
                # expiry would raise here - confirm against the model.
                expire_time = (
                    api_key.expire.replace(tzinfo=UTC)
                    if api_key.expire.tzinfo is None
                    else api_key.expire
                )
                if datetime.now(UTC) > expire_time:
                    return jsonify({"error": "API key has expired"}), 401

                # A key without roles is treated as holding none.
                if required_roles and not (required_roles & set(api_key.roles or ())):
                    return jsonify({"error": "Forbidden, permission denied"}), 403

                update_last_used(api_key)
                return func(*args, **kwargs)

            token = _extract_bearer_token(auth_header)
            if token is None:
                return jsonify({"error": "Unauthorized"}), 401

            decoded = verify_token(token)
            if not decoded:
                return jsonify({"error": "Invalid or expired token"}), 401

            if required_roles and not (required_roles & set(_client_roles(decoded))):
                logger.warning("auth failed")
                return jsonify({"error": "Forbidden, permission denied"}), 403

            logger.debug("auth success")
            return func(*args, **kwargs)

        return wrapper

    return decorator


# Rate-limit strings keyed by "<rule> : <METHOD>"; filled by init_rate_limits.
rate_limits = {}
default_rate_limit = "60 per minute"


def init_rate_limits(app):
    """Populate ``rate_limits`` with the default limit for every route/method."""
    global rate_limits
    rate_limits = {
        f"{rule.rule} : {method}": default_rate_limit
        for rule in app.url_map.iter_rules()
        for method in (rule.methods or ())
    }


def get_rate_limit():
    """Return the rate-limit string for the current request's route and method.

    The lookup uses the matched URL rule (the same string used as the key
    in ``init_rate_limits``) so routes with path parameters resolve too;
    it falls back to the raw path when no rule matched.
    """
    matched_rule = request.url_rule
    route = matched_rule.rule if matched_rule is not None else request.path
    return rate_limits.get(f"{route} : {request.method}", default_rate_limit)


limiter = Limiter(
    key_func=get_remote_address,
    default_limits=["100 per minute"]
)


def register_blueprints(app):
    """Import every submodule of this package and register its blueprints.

    Blueprints whose name is already registered on ``app`` are skipped.
    """
    package_name = __name__
    package_path = os.path.dirname(__file__)

    for _, mod_name, _ in pkgutil.walk_packages([package_path], package_name + "."):
        module = importlib.import_module(mod_name)
        for attr_name in dir(module):
            item = getattr(module, attr_name)
            if not isinstance(item, Blueprint):
                continue
            if item.name in app.blueprints:
                continue
            app.register_blueprint(item)


def generate_etag(data):
    """Return an MD5 hex digest of ``data`` serialised as key-sorted JSON."""
    serialized = json.dumps(data, sort_keys=True).encode('utf-8')
    # The digest is a cache validator, not a security primitive.
    return hashlib.md5(serialized, usedforsecurity=False).hexdigest()


def etag_response(f):
    """Decorator adding an ``ETag`` header to ``(body, status)`` responses.

    Only applies when the view returns a tuple whose status is 200 or 201
    and whose body is a ``dict`` or ``list``; any other return value is
    passed through unchanged.
    """
    @wraps(f)
    def wrapper(*args, **kwargs):
        response = f(*args, **kwargs)

        # Response objects and bare bodies cannot be indexed as
        # (body, status); hand them back untouched.
        if not isinstance(response, tuple) or len(response) < 2:
            return response

        body, status = response[0], response[1]
        if status not in (200, 201) or not isinstance(body, (dict, list)):
            return response

        etag = generate_etag(body)
        if request.headers.get("If-None-Match") == etag:
            # NOTE(review): HTTP convention is 304 Not Modified here; the
            # existing 200-with-empty-body contract is kept for clients.
            return jsonify({}), 200

        resp = make_response(body, status)
        resp.headers["ETag"] = etag
        return resp

    return wrapper


def generate_api_key(length=32):
    """Return a random alphanumeric string of ``length`` characters."""
    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for _ in range(length))


def get_api_key(key):
    """Return the active ``APIKey`` row matching ``key``, or ``None``."""
    with get_db() as session:
        return session.query(APIKey).filter_by(key=key, is_active=True).first()


def update_last_used(api_key):
    """Record the current UTC time as the key's ``last_used_at`` and commit."""
    with get_db() as session:
        api_key.last_used_at = datetime.now(UTC)
        # ``api_key`` was presumably loaded through a different session
        # (see get_api_key); merge it so this session's commit actually
        # persists the change.
        session.merge(api_key)
        session.commit()