From 48dd59f8e4a1ef85f07ab05dd47f037d03e9b4d8 Mon Sep 17 00:00:00 2001 From: hzhang Date: Fri, 6 Dec 2024 10:04:03 +0000 Subject: [PATCH] kc token public key/token issue, path root set to 1 --- api/__init__.py | 82 +++++++++++++++++++------- api/log.py | 4 +- api/markdown.py | 34 +++++------ api/path.py | 43 +++++++------- api/resource.py | 20 +++---- app.py | 22 ++++++- db/models/Path.py | 4 +- db/utils.py | 6 +- logging_handlers/DatabaseLogHandler.py | 6 +- 9 files changed, 139 insertions(+), 82 deletions(-) diff --git a/api/__init__.py b/api/__init__.py index 710361a..93ebc91 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -1,63 +1,101 @@ #api/__init__.py +import base64 import os from functools import wraps -from flask import jsonify, session, Blueprint, request, g + +from cryptography import x509 +from cryptography.hazmat.primitives import serialization +from flask import jsonify, Blueprint, request from flask_limiter import Limiter from flask_limiter.util import get_remote_address -from jwt import decode, ExpiredSignatureError, InvalidTokenError +from jwt import decode, ExpiredSignatureError, InvalidTokenError, get_unverified_header import importlib import requests from threading import Lock import env_provider -_public_key_cache = None +_public_key_cache = {} _lock = Lock() -def keycloak_public_key(): - global _public_key_cache - if _public_key_cache: - return _public_key_cache - with _lock: - if _public_key_cache: - return _public_key_cache - url = f"{env_provider.KC_HOST}/realms/{env_provider.KC_REALM}/protocol/openid-connect/certs" - response = requests.get(url) - jwks = response.json() - public_key = jwks["keys"][0]["x5c"][0] - _public_key_cache = f"-----BEGIN CERTIFICATE-----\n{public_key}\n-----END CERTIFICATE-----" - return _public_key_cache +def x5c_to_public_key(x5c): + cert_der = base64.b64decode(x5c) + cert = x509.load_der_x509_certificate(cert_der) + public_key = cert.public_key() + pem = public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + return pem + +def get_jwks(): + url = f"{env_provider.KC_HOST}/realms/{env_provider.KC_REALM}/protocol/openid-connect/certs" + response = requests.get(url) + jwks = response.json() + return jwks + +def get_public_key_for_kid(kid): + global _public_key_cache + if kid in _public_key_cache: + return _public_key_cache[kid] + jwks = get_jwks() + res = [] + for key_data in jwks["keys"]: + if key_data["kid"] == kid and key_data["use"] == "sig" and key_data["alg"] == "RS256" and key_data["kty"] == "RSA": + x5c = key_data["x5c"][0] + pem_public_key = x5c_to_public_key(x5c) + _public_key_cache[kid] = pem_public_key + res.append(pem_public_key) + if len(res) > 0: + print(len(res)) + return res[0] + + return None + def verify_token(token): try: + header = get_unverified_header(token) + kid = header.get("kid") + if not kid: + return None + public_key = get_public_key_for_kid(kid) + if not public_key: + return None decoded = decode( token, - keycloak_public_key(), + public_key, algorithms=["RS256"], audience=env_provider.KC_CLIENT_ID ) return decoded - except ExpiredSignatureError: + except ExpiredSignatureError as e: + print(e) return None - except InvalidTokenError: + except InvalidTokenError as e: + print(e) return None def require_auth(roles=[]): def decorator(func): @wraps(func) def wrapper(*args, **kwargs): + if request.method == "OPTIONS": + return '', 200 auth_header = request.headers.get('Authorization') if not auth_header or not auth_header.startswith('Bearer'): return jsonify({"error": "Unauthorized"}), 401 token = auth_header.split(" ")[1] + decoded = verify_token(token) if not decoded: return jsonify({"error": "Invalid or expired token"}), 401 - user_roles = decoded.get("roles", []) - if roles and not set(roles).issubset(set(user_roles)): + user_roles = decoded.get("resource_access", {}).get(env_provider.KC_CLIENT_ID, {}).get("roles", []) + if roles and not (set(roles) & set(user_roles)): + print("auth failed") return jsonify({"error": "Forbidden, permission denied"}), 403 - g.user = decoded + print("auth success") return func(*args, **kwargs) return wrapper return decorator diff --git a/api/log.py b/api/log.py index 4513b8b..39afb25 100644 --- a/api/log.py +++ b/api/log.py @@ -13,8 +13,8 @@ def get_logs(): page = int(request.args.get('page', 1)) per_page = int(request.args.get('per_page', 10)) - with get_db() as db: - query = db.query(Log) + with get_db() as session: + query = session.query(Log) if level: query = query.filter(Log.level == level) if application: diff --git a/api/markdown.py b/api/markdown.py index 0312b5d..0697a5e 100644 --- a/api/markdown.py +++ b/api/markdown.py @@ -14,15 +14,15 @@ markdown_bp = Blueprint('markdown', __name__, url_prefix='/api/markdown') @markdown_bp.route('/', methods=['GET']) @limiter.limit('5 per minute') def get_markdowns(): - with get_db() as db: - mds = db.query(Markdown).all() + with get_db() as session: + mds = session.query(Markdown).all() return jsonify([md.to_dict() for md in mds]), 200 @markdown_bp.route('/by_path/', methods=['GET']) @limiter.limit('5 per minute') def get_markdowns_by_path(path_id): - with get_db() as db: - markdowns = db.query(Markdown).filter(Markdown.path_id == path_id).all() + with get_db() as session: + markdowns = session.query(Markdown).filter(Markdown.path_id == path_id).all() return jsonify([md.to_dict() for md in markdowns]), 200 @@ -30,8 +30,8 @@ def get_markdowns_by_path(path_id): @markdown_bp.route('/', methods=['GET']) @limiter.limit('120 per minute') def get_markdown(markdown_id): - with get_db() as db: - markdown = db.query(Markdown).get(markdown_id) + with get_db() as session: + markdown = session.query(Markdown).get(markdown_id) if markdown is None: return jsonify({"error": "file not found"}), 404 return jsonify(markdown.to_dict()) @@ -47,42 +47,42 @@ def create_markdown(): if not title or not content: return jsonify({"error": "missing required fields"}), 400 new_markdown = Markdown(title=title, content=content, path_id=path_id) - with get_db() as db: + with get_db() as session: try: - db.add(new_markdown) - db.commit() + session.add(new_markdown) + session.commit() return jsonify(new_markdown.to_dict()), 201 except Exception as e: logger.error(f"failed to create markdown: {e}") errno = RequestContext.get_error_id() - db.rollback() + session.rollback() return jsonify({"error": f"create failed - {errno}"}), 500 @markdown_bp.route('/', methods=['PUT']) @require_auth(roles=['admin', 'creator']) @limiter.limit('20 per minute') def update_markdown(markdown_id): - with get_db() as db: - markdown = db.query(Markdown).get(markdown_id) + with get_db() as session: + markdown = session.query(Markdown).get(markdown_id) if markdown is None: return jsonify({"error": "file not found"}), 404 data = request.json markdown.title = data.get('title') markdown.content = data.get('content') markdown.path_id = data.get('path_id') - db.commit() + session.commit() return jsonify(markdown.to_dict()), 200 @markdown_bp.route('/', methods=['DELETE']) @require_auth(roles=['admin']) @limiter.limit('20 per minute') def delete_markdown(markdown_id): - with get_db() as db: - markdown = db.query(Markdown).get(markdown_id) + with get_db() as session: + markdown = session.query(Markdown).get(markdown_id) if markdown is None: logger.error(f"failed to delete markdown: {markdown_id}") errno = RequestContext.get_error_id() return jsonify({"error": f"file not found - {errno}"}), 404 - db.delete(markdown) - db.commit() + session.delete(markdown) + session.commit() return jsonify({"message": "deleted"}), 200 \ No newline at end of file diff --git a/api/path.py b/api/path.py index 67a5f1b..8ba774c 100644 --- a/api/path.py +++ b/api/path.py @@ -12,15 +12,15 @@ path_bp = Blueprint('path', __name__, url_prefix='/api/path') @path_bp.route('/', methods=['GET']) @limiter.limit('5 per minute') def get_root_paths(): - with get_db() as db: - paths = db.query(Path).filter(Path.parent_id == 0) + with get_db() as session: + paths = session.query(Path).filter(Path.parent_id == 1) return jsonify([pth.to_dict() for pth in paths]), 200 @path_bp.route('/', methods=['GET']) @limiter.limit('5 per minute') def get_path(path_id): - with get_db() as db: - path = db.query(Path).get(path_id) + with get_db() as session: + path = session.query(Path).get(path_id) if path is None: return jsonify({"error": "file not found"}), 404 return jsonify(path.to_dict()), 200 @@ -28,8 +28,8 @@ def get_path(path_id): @path_bp.route('/parent/', methods=['GET']) @limiter.limit('5 per minute') def get_path_by_parent(parent_id): - with get_db() as db: - paths = db.query(Path).filter(Path.parent_id == parent_id).all() + with get_db() as session: + paths = session.query(Path).filter(Path.parent_id == parent_id).all() return jsonify([pth.to_dict() for pth in paths]), 200 @path_bp.route('/', methods=['POST']) @@ -39,15 +39,14 @@ def create_path(): data = request.json if not data or 'name' not in data or 'parent_id' not in data: return jsonify({"error": "bad request"}), 400 - with get_db() as db: - if data['parent_id'] != 0 and not db.query(Path).get(data['parent_id']): + with get_db() as session: + if data['parent_id'] != 0 and not session.query(Path).get(data['parent_id']): return jsonify({"error": "path not found"}), 404 - if db.query(Path).filter_by(name=data['name'], parent_id=data['parent_id']).first(): + if session.query(Path).filter_by(name=data['name'], parent_id=data['parent_id']).first(): return jsonify({"error": "Path already exists under the parent"}), 409 - new_path = Path(name=data['name'], parent_id=data['parent_id']) - db.add(new_path) - db.commit() + session.add(new_path) + session.commit() return jsonify(new_path.to_dict()), 201 @path_bp.route('/', methods=['PUT']) @@ -57,30 +56,30 @@ def update_path(path_id): data = request.json if not data or 'name' not in data or 'parent_id' not in data: return jsonify({"error": "bad request"}), 400 - with get_db() as db: - path = db.query(Path).get(path_id) + with get_db() as session: + path = session.query(Path).get(path_id) if path is None: return jsonify({"error": "path not found"}), 404 - if db.query(Path).filter_by(name=data['name'], parent_id=data['parent_id']).first(): + if session.query(Path).filter_by(name=data['name'], parent_id=data['parent_id']).first(): return jsonify({"error": "Path already exists under the parent"}), 409 path.name = data['name'] path.parent_id = data['parent_id'] - db.commit() + session.commit() return jsonify(path.to_dict()), 200 @path_bp.route('/', methods=['DELETE']) @limiter.limit('60 per minute') @require_auth(roles=['admin']) def delete_path(path_id): - with get_db() as db: - path = db.query(Path).get(path_id) + with get_db() as session: + path = session.query(Path).get(path_id) if not path: return jsonify({"error": "path not found"}), 404 - if db.query(Path).filter_by(parent_id=path_id).first(): + if session.query(Path).filter_by(parent_id=path_id).first(): return jsonify({"error": "can not delete non empty path"}), 409 - if db.query(Markdown).filter_by(path_id=path_id).first(): + if session.query(Markdown).filter_by(path_id=path_id).first(): return jsonify({"error": "can not delete non empty path"}), 409 - db.delete(path) - db.commit() + session.delete(path) + session.commit() return jsonify({"message": "path deleted"}), 200 diff --git a/api/resource.py b/api/resource.py index 44f3e9c..c9b926d 100644 --- a/api/resource.py +++ b/api/resource.py @@ -11,8 +11,8 @@ logger = logging.getLogger(__name__) @resource_bp.route('/', methods=['GET']) @limiter.limit('10 per second') def get_resource(identifier): - with get_db() as db: - resource = db.query(Resource).get(identifier) + with get_db() as session: + resource = session.query(Resource).get(identifier) if resource is None: logger.error(f"resource not found: {identifier}") errno = RequestContext.get_error_id() @@ -30,13 +30,13 @@ def create_resource(): if not identifier or not content or not data_type: return jsonify({"error": "missing required fields"}), 400 resource_entry = Resource(id=identifier, content=content, data_type=data_type) - with get_db() as db: + with get_db() as session: try: - db.add(resource_entry) - db.commit() + session.add(resource_entry) + session.commit() return jsonify(resource_entry.to_dict()), 201 except Exception as e: - db.rollback() + session.rollback() logger.error(f"Failed to create resource: {e}") errno = RequestContext.get_error_id() return jsonify({"error": f"failed to create resource - {errno}"}), 500 @@ -45,12 +45,12 @@ def create_resource(): @require_auth(roles=["admin"]) @limiter.limit('20 per minute') def delete_resource(identifier): - with get_db() as db: - resource = db.query(Resource).get(identifier) + with get_db() as session: + resource = session.query(Resource).get(identifier) if not resource: logger.error(f"resource not found: {identifier}") errno = RequestContext.get_error_id() return jsonify({"error": f"Resource not found - {errno}"}), 404 - db.delete(resource) - db.commit() + session.delete(resource) + session.commit() return jsonify({"message": "Resource deleted"}), 200 \ No newline at end of file diff --git a/app.py b/app.py index 4d77b0a..352af7d 100644 --- a/app.py +++ b/app.py @@ -1,6 +1,9 @@ # app.py from pprint import pprint +from sqlalchemy import text + +from db.models.Path import Path from logging_handlers.DatabaseLogHandler import DatabaseLogHandler from urllib.parse import urlparse from api import limiter @@ -24,8 +27,25 @@ try: db.create_all() except Exception as e: print(f"db not ready {e}") + + +try: + with db.get_db() as session: + root_path = session.query(Path).filter(Path.id == 1).first() + if not root_path: + session.execute(text("SET FOREIGN_KEY_CHECKS=0;")) + #session.execute(text("ALTER TABLE path AUTO_INCREMENT = 0;")) + session.execute(text("ALTER TABLE path MODIFY COLUMN id INT;")) + root_path = Path(id=1, name="") + session.add(root_path) + session.commit() + session.execute(text("ALTER TABLE path MODIFY COLUMN id INT AUTO_INCREMENT;")) + session.execute(text("SET FOREIGN_KEY_CHECKS=1;")) + logger.info("Root path created") +except Exception as e: + logger.error(f"Failed to create root path {e}") app = Flask(__name__) -app.config['SERVER_NAME'] = env_provider.BACKEND_HOST +#app.config['SERVER_NAME'] = env_provider.BACKEND_HOST app.secret_key = env_provider.SESSION_SECRET_KEY CORS(app, resources={r"/api/*": {"origins": [ env_provider.KC_HOST, diff --git a/db/models/Path.py b/db/models/Path.py index 3c336d1..e9aa10a 100644 --- a/db/models/Path.py +++ b/db/models/Path.py @@ -1,5 +1,5 @@ #db/models/Path.py -from sqlalchemy import Column, Text, LargeBinary, String, Integer, ForeignKey, UniqueConstraint +from sqlalchemy import Column, String, Integer, ForeignKey, UniqueConstraint from db.models import Base @@ -7,7 +7,7 @@ class Path(Base): __tablename__ = "path" id = Column(Integer, primary_key=True, autoincrement=True) name = Column(String(50), nullable=False) - parent_id = Column(Integer, ForeignKey("path.id")) + parent_id = Column(Integer, ForeignKey("path.id"), nullable=True) __table_args__ = (UniqueConstraint("parent_id", "name", name="unique_parent_id_name"), ) def to_dict(self): return { diff --git a/db/utils.py b/db/utils.py index 600e4fc..5284e55 100644 --- a/db/utils.py +++ b/db/utils.py @@ -2,6 +2,6 @@ from db import get_db def insert_log(log_entry): - with get_db() as db: - db.add(log_entry) - db.commit() \ No newline at end of file + with get_db() as session: + session.add(log_entry) + session.commit() \ No newline at end of file diff --git a/logging_handlers/DatabaseLogHandler.py b/logging_handlers/DatabaseLogHandler.py index 770ede0..5c6d7c3 100644 --- a/logging_handlers/DatabaseLogHandler.py +++ b/logging_handlers/DatabaseLogHandler.py @@ -15,9 +15,9 @@ class DatabaseLogHandler(logging.Handler): log_entry = Log(message=message, level=level, application=self.application, extra=extra) print(message) try: - with get_db() as db: - db.add(log_entry) - db.commit() + with get_db() as session: + session.add(log_entry) + session.commit() RequestContext.set_error_id(log_entry.id) except Exception: print(f"failed to log") \ No newline at end of file