Compare commits

...

1 Commits

Author SHA1 Message Date
48dd59f8e4 kc token public key/token issue, path root set to 1 2024-12-06 10:04:03 +00:00
9 changed files with 139 additions and 82 deletions

View File

@@ -1,63 +1,101 @@
#api/__init__.py #api/__init__.py
import base64
import os import os
from functools import wraps from functools import wraps
from flask import jsonify, session, Blueprint, request, g
from cryptography import x509
from cryptography.hazmat.primitives import serialization
from flask import jsonify, Blueprint, request
from flask_limiter import Limiter from flask_limiter import Limiter
from flask_limiter.util import get_remote_address from flask_limiter.util import get_remote_address
from jwt import decode, ExpiredSignatureError, InvalidTokenError from jwt import decode, ExpiredSignatureError, InvalidTokenError, get_unverified_header
import importlib import importlib
import requests import requests
from threading import Lock from threading import Lock
import env_provider import env_provider
_public_key_cache = None _public_key_cache = {}
_lock = Lock() _lock = Lock()
def keycloak_public_key():
global _public_key_cache
if _public_key_cache:
return _public_key_cache
with _lock:
if _public_key_cache:
return _public_key_cache
url = f"{env_provider.KC_HOST}/realms/{env_provider.KC_REALM}/protocol/openid-connect/certs" def x5c_to_public_key(x5c):
response = requests.get(url) cert_der = base64.b64decode(x5c)
jwks = response.json() cert = x509.load_der_x509_certificate(cert_der)
public_key = jwks["keys"][0]["x5c"][0] public_key = cert.public_key()
_public_key_cache = f"-----BEGIN CERTIFICATE-----\n{public_key}\n-----END CERTIFICATE-----" pem = public_key.public_bytes(
return _public_key_cache encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
return pem
def get_jwks():
url = f"{env_provider.KC_HOST}/realms/{env_provider.KC_REALM}/protocol/openid-connect/certs"
response = requests.get(url)
jwks = response.json()
return jwks
def get_public_key_for_kid(kid):
global _public_key_cache
if kid in _public_key_cache:
return _public_key_cache[kid]
jwks = get_jwks()
res = []
for key_data in jwks["keys"]:
if key_data["kid"] == kid and key_data["use"] == "sig" and key_data["alg"] == "RS256" and key_data["kty"] == "RSA":
x5c = key_data["x5c"][0]
pem_public_key = x5c_to_public_key(x5c)
_public_key_cache[kid] = pem_public_key
res.append(pem_public_key)
if len(res) > 0:
print(len(res))
return res[0]
return None
def verify_token(token): def verify_token(token):
try: try:
header = get_unverified_header(token)
kid = header.get("kid")
if not kid:
return None
public_key = get_public_key_for_kid(kid)
if not public_key:
return None
decoded = decode( decoded = decode(
token, token,
keycloak_public_key(), public_key,
algorithms=["RS256"], algorithms=["RS256"],
audience=env_provider.KC_CLIENT_ID audience=env_provider.KC_CLIENT_ID
) )
return decoded return decoded
except ExpiredSignatureError: except ExpiredSignatureError as e:
print(e)
return None return None
except InvalidTokenError: except InvalidTokenError as e:
print(e)
return None return None
def require_auth(roles=[]): def require_auth(roles=[]):
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if request.method == "OPTIONS":
return '', 200
auth_header = request.headers.get('Authorization') auth_header = request.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Bearer'): if not auth_header or not auth_header.startswith('Bearer'):
return jsonify({"error": "Unauthorized"}), 401 return jsonify({"error": "Unauthorized"}), 401
token = auth_header.split(" ")[1] token = auth_header.split(" ")[1]
decoded = verify_token(token) decoded = verify_token(token)
if not decoded: if not decoded:
return jsonify({"error": "Invalid or expired token"}), 401 return jsonify({"error": "Invalid or expired token"}), 401
user_roles = decoded.get("roles", []) user_roles = decoded.get("resource_access", {}).get(env_provider.KC_CLIENT_ID, {}).get("roles", [])
if roles and not set(roles).issubset(set(user_roles)): if roles and not (set(roles) & set(user_roles)):
print("auth failed")
return jsonify({"error": "Forbidden, permission denied"}), 403 return jsonify({"error": "Forbidden, permission denied"}), 403
g.user = decoded print("auth success")
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper return wrapper
return decorator return decorator

View File

@@ -13,8 +13,8 @@ def get_logs():
page = int(request.args.get('page', 1)) page = int(request.args.get('page', 1))
per_page = int(request.args.get('per_page', 10)) per_page = int(request.args.get('per_page', 10))
with get_db() as db: with get_db() as session:
query = db.query(Log) query = session.query(Log)
if level: if level:
query = query.filter(Log.level == level) query = query.filter(Log.level == level)
if application: if application:

View File

@@ -14,15 +14,15 @@ markdown_bp = Blueprint('markdown', __name__, url_prefix='/api/markdown')
@markdown_bp.route('/', methods=['GET']) @markdown_bp.route('/', methods=['GET'])
@limiter.limit('5 per minute') @limiter.limit('5 per minute')
def get_markdowns(): def get_markdowns():
with get_db() as db: with get_db() as session:
mds = db.query(Markdown).all() mds = session.query(Markdown).all()
return jsonify([md.to_dict() for md in mds]), 200 return jsonify([md.to_dict() for md in mds]), 200
@markdown_bp.route('/by_path/<int:path_id>', methods=['GET']) @markdown_bp.route('/by_path/<int:path_id>', methods=['GET'])
@limiter.limit('5 per minute') @limiter.limit('5 per minute')
def get_markdowns_by_path(path_id): def get_markdowns_by_path(path_id):
with get_db() as db: with get_db() as session:
markdowns = db.query(Markdown).filter(Markdown.path_id == path_id).all() markdowns = session.query(Markdown).filter(Markdown.path_id == path_id).all()
return jsonify([md.to_dict() for md in markdowns]), 200 return jsonify([md.to_dict() for md in markdowns]), 200
@@ -30,8 +30,8 @@ def get_markdowns_by_path(path_id):
@markdown_bp.route('/<int:markdown_id>', methods=['GET']) @markdown_bp.route('/<int:markdown_id>', methods=['GET'])
@limiter.limit('120 per minute') @limiter.limit('120 per minute')
def get_markdown(markdown_id): def get_markdown(markdown_id):
with get_db() as db: with get_db() as session:
markdown = db.query(Markdown).get(markdown_id) markdown = session.query(Markdown).get(markdown_id)
if markdown is None: if markdown is None:
return jsonify({"error": "file not found"}), 404 return jsonify({"error": "file not found"}), 404
return jsonify(markdown.to_dict()) return jsonify(markdown.to_dict())
@@ -47,42 +47,42 @@ def create_markdown():
if not title or not content: if not title or not content:
return jsonify({"error": "missing required fields"}), 400 return jsonify({"error": "missing required fields"}), 400
new_markdown = Markdown(title=title, content=content, path_id=path_id) new_markdown = Markdown(title=title, content=content, path_id=path_id)
with get_db() as db: with get_db() as session:
try: try:
db.add(new_markdown) session.add(new_markdown)
db.commit() session.commit()
return jsonify(new_markdown.to_dict()), 201 return jsonify(new_markdown.to_dict()), 201
except Exception as e: except Exception as e:
logger.error(f"failed to create markdown: {e}") logger.error(f"failed to create markdown: {e}")
errno = RequestContext.get_error_id() errno = RequestContext.get_error_id()
db.rollback() session.rollback()
return jsonify({"error": f"create failed - {errno}"}), 500 return jsonify({"error": f"create failed - {errno}"}), 500
@markdown_bp.route('/<int:markdown_id>', methods=['PUT']) @markdown_bp.route('/<int:markdown_id>', methods=['PUT'])
@require_auth(roles=['admin', 'creator']) @require_auth(roles=['admin', 'creator'])
@limiter.limit('20 per minute') @limiter.limit('20 per minute')
def update_markdown(markdown_id): def update_markdown(markdown_id):
with get_db() as db: with get_db() as session:
markdown = db.query(Markdown).get(markdown_id) markdown = session.query(Markdown).get(markdown_id)
if markdown is None: if markdown is None:
return jsonify({"error": "file not found"}), 404 return jsonify({"error": "file not found"}), 404
data = request.json data = request.json
markdown.title = data.get('title') markdown.title = data.get('title')
markdown.content = data.get('content') markdown.content = data.get('content')
markdown.path_id = data.get('path_id') markdown.path_id = data.get('path_id')
db.commit() session.commit()
return jsonify(markdown.to_dict()), 200 return jsonify(markdown.to_dict()), 200
@markdown_bp.route('/<int:markdown_id>', methods=['DELETE']) @markdown_bp.route('/<int:markdown_id>', methods=['DELETE'])
@require_auth(roles=['admin']) @require_auth(roles=['admin'])
@limiter.limit('20 per minute') @limiter.limit('20 per minute')
def delete_markdown(markdown_id): def delete_markdown(markdown_id):
with get_db() as db: with get_db() as session:
markdown = db.query(Markdown).get(markdown_id) markdown = session.query(Markdown).get(markdown_id)
if markdown is None: if markdown is None:
logger.error(f"failed to delete markdown: {markdown_id}") logger.error(f"failed to delete markdown: {markdown_id}")
errno = RequestContext.get_error_id() errno = RequestContext.get_error_id()
return jsonify({"error": f"file not found - {errno}"}), 404 return jsonify({"error": f"file not found - {errno}"}), 404
db.delete(markdown) session.delete(markdown)
db.commit() session.commit()
return jsonify({"message": "deleted"}), 200 return jsonify({"message": "deleted"}), 200

View File

@@ -12,15 +12,15 @@ path_bp = Blueprint('path', __name__, url_prefix='/api/path')
@path_bp.route('/', methods=['GET']) @path_bp.route('/', methods=['GET'])
@limiter.limit('5 per minute') @limiter.limit('5 per minute')
def get_root_paths(): def get_root_paths():
with get_db() as db: with get_db() as session:
paths = db.query(Path).filter(Path.parent_id == 0) paths = session.query(Path).filter(Path.parent_id == 1)
return jsonify([pth.to_dict() for pth in paths]), 200 return jsonify([pth.to_dict() for pth in paths]), 200
@path_bp.route('/<int:path_id>', methods=['GET']) @path_bp.route('/<int:path_id>', methods=['GET'])
@limiter.limit('5 per minute') @limiter.limit('5 per minute')
def get_path(path_id): def get_path(path_id):
with get_db() as db: with get_db() as session:
path = db.query(Path).get(path_id) path = session.query(Path).get(path_id)
if path is None: if path is None:
return jsonify({"error": "file not found"}), 404 return jsonify({"error": "file not found"}), 404
return jsonify(path.to_dict()), 200 return jsonify(path.to_dict()), 200
@@ -28,8 +28,8 @@ def get_path(path_id):
@path_bp.route('/parent/<int:parent_id>', methods=['GET']) @path_bp.route('/parent/<int:parent_id>', methods=['GET'])
@limiter.limit('5 per minute') @limiter.limit('5 per minute')
def get_path_by_parent(parent_id): def get_path_by_parent(parent_id):
with get_db() as db: with get_db() as session:
paths = db.query(Path).filter(Path.parent_id == parent_id).all() paths = session.query(Path).filter(Path.parent_id == parent_id).all()
return jsonify([pth.to_dict() for pth in paths]), 200 return jsonify([pth.to_dict() for pth in paths]), 200
@path_bp.route('/', methods=['POST']) @path_bp.route('/', methods=['POST'])
@@ -39,15 +39,14 @@ def create_path():
data = request.json data = request.json
if not data or 'name' not in data or 'parent_id' not in data: if not data or 'name' not in data or 'parent_id' not in data:
return jsonify({"error": "bad request"}), 400 return jsonify({"error": "bad request"}), 400
with get_db() as db: with get_db() as session:
if data['parent_id'] != 0 and not db.query(Path).get(data['parent_id']): if data['parent_id'] != 0 and not session.query(Path).get(data['parent_id']):
return jsonify({"error": "path not found"}), 404 return jsonify({"error": "path not found"}), 404
if db.query(Path).filter_by(name=data['name'], parent_id=data['parent_id']).first(): if session.query(Path).filter_by(name=data['name'], parent_id=data['parent_id']).first():
return jsonify({"error": "Path already exists under the parent"}), 409 return jsonify({"error": "Path already exists under the parent"}), 409
new_path = Path(name=data['name'], parent_id=data['parent_id']) new_path = Path(name=data['name'], parent_id=data['parent_id'])
db.add(new_path) session.add(new_path)
db.commit() session.commit()
return jsonify(new_path.to_dict()), 201 return jsonify(new_path.to_dict()), 201
@path_bp.route('/<int:path_id>', methods=['PUT']) @path_bp.route('/<int:path_id>', methods=['PUT'])
@@ -57,30 +56,30 @@ def update_path(path_id):
data = request.json data = request.json
if not data or 'name' not in data or 'parent_id' not in data: if not data or 'name' not in data or 'parent_id' not in data:
return jsonify({"error": "bad request"}), 400 return jsonify({"error": "bad request"}), 400
with get_db() as db: with get_db() as session:
path = db.query(Path).get(path_id) path = session.query(Path).get(path_id)
if path is None: if path is None:
return jsonify({"error": "path not found"}), 404 return jsonify({"error": "path not found"}), 404
if db.query(Path).filter_by(name=data['name'], parent_id=data['parent_id']).first(): if session.query(Path).filter_by(name=data['name'], parent_id=data['parent_id']).first():
return jsonify({"error": "Path already exists under the parent"}), 409 return jsonify({"error": "Path already exists under the parent"}), 409
path.name = data['name'] path.name = data['name']
path.parent_id = data['parent_id'] path.parent_id = data['parent_id']
db.commit() session.commit()
return jsonify(path.to_dict()), 200 return jsonify(path.to_dict()), 200
@path_bp.route('/<int:path_id>', methods=['DELETE']) @path_bp.route('/<int:path_id>', methods=['DELETE'])
@limiter.limit('60 per minute') @limiter.limit('60 per minute')
@require_auth(roles=['admin']) @require_auth(roles=['admin'])
def delete_path(path_id): def delete_path(path_id):
with get_db() as db: with get_db() as session:
path = db.query(Path).get(path_id) path = session.query(Path).get(path_id)
if not path: if not path:
return jsonify({"error": "path not found"}), 404 return jsonify({"error": "path not found"}), 404
if db.query(Path).filter_by(parent_id=path_id).first(): if session.query(Path).filter_by(parent_id=path_id).first():
return jsonify({"error": "can not delete non empty path"}), 409 return jsonify({"error": "can not delete non empty path"}), 409
if db.query(Markdown).filter_by(path_id=path_id).first(): if session.query(Markdown).filter_by(path_id=path_id).first():
return jsonify({"error": "can not delete non empty path"}), 409 return jsonify({"error": "can not delete non empty path"}), 409
db.delete(path) session.delete(path)
db.commit() session.commit()
return jsonify({"message": "path deleted"}), 200 return jsonify({"message": "path deleted"}), 200

View File

@@ -11,8 +11,8 @@ logger = logging.getLogger(__name__)
@resource_bp.route('/<identifier>', methods=['GET']) @resource_bp.route('/<identifier>', methods=['GET'])
@limiter.limit('10 per second') @limiter.limit('10 per second')
def get_resource(identifier): def get_resource(identifier):
with get_db() as db: with get_db() as session:
resource = db.query(Resource).get(identifier) resource = session.query(Resource).get(identifier)
if resource is None: if resource is None:
logger.error(f"resource not found: {identifier}") logger.error(f"resource not found: {identifier}")
errno = RequestContext.get_error_id() errno = RequestContext.get_error_id()
@@ -30,13 +30,13 @@ def create_resource():
if not identifier or not content or not data_type: if not identifier or not content or not data_type:
return jsonify({"error": "missing required fields"}), 400 return jsonify({"error": "missing required fields"}), 400
resource_entry = Resource(id=identifier, content=content, data_type=data_type) resource_entry = Resource(id=identifier, content=content, data_type=data_type)
with get_db() as db: with get_db() as session:
try: try:
db.add(resource_entry) session.add(resource_entry)
db.commit() session.commit()
return jsonify(resource_entry.to_dict()), 201 return jsonify(resource_entry.to_dict()), 201
except Exception as e: except Exception as e:
db.rollback() session.rollback()
logger.error(f"Failed to create resource: {e}") logger.error(f"Failed to create resource: {e}")
errno = RequestContext.get_error_id() errno = RequestContext.get_error_id()
return jsonify({"error": f"failed to create resource - {errno}"}), 500 return jsonify({"error": f"failed to create resource - {errno}"}), 500
@@ -45,12 +45,12 @@ def create_resource():
@require_auth(roles=["admin"]) @require_auth(roles=["admin"])
@limiter.limit('20 per minute') @limiter.limit('20 per minute')
def delete_resource(identifier): def delete_resource(identifier):
with get_db() as db: with get_db() as session:
resource = db.query(Resource).get(identifier) resource = session.query(Resource).get(identifier)
if not resource: if not resource:
logger.error(f"resource not found: {identifier}") logger.error(f"resource not found: {identifier}")
errno = RequestContext.get_error_id() errno = RequestContext.get_error_id()
return jsonify({"error": f"Resource not found - {errno}"}), 404 return jsonify({"error": f"Resource not found - {errno}"}), 404
db.delete(resource) session.delete(resource)
db.commit() session.commit()
return jsonify({"message": "Resource deleted"}), 200 return jsonify({"message": "Resource deleted"}), 200

22
app.py
View File

@@ -1,6 +1,9 @@
# app.py # app.py
from pprint import pprint from pprint import pprint
from sqlalchemy import text
from db.models.Path import Path
from logging_handlers.DatabaseLogHandler import DatabaseLogHandler from logging_handlers.DatabaseLogHandler import DatabaseLogHandler
from urllib.parse import urlparse from urllib.parse import urlparse
from api import limiter from api import limiter
@@ -24,8 +27,25 @@ try:
db.create_all() db.create_all()
except Exception as e: except Exception as e:
print(f"db not ready {e}") print(f"db not ready {e}")
try:
with db.get_db() as session:
root_path = session.query(Path).filter(Path.id == 1).first()
if not root_path:
session.execute(text("SET FOREIGN_KEY_CHECKS=0;"))
#session.execute(text("ALTER TABLE path AUTO_INCREMENT = 0;"))
session.execute(text("ALTER TABLE path MODIFY COLUMN id INT;"))
root_path = Path(id=1, name="")
session.add(root_path)
session.commit()
session.execute(text("ALTER TABLE path MODIFY COLUMN id INT AUTO_INCREMENT;"))
session.execute(text("SET FOREIGN_KEY_CHECKS=1;"))
logger.info("Root path created")
except Exception as e:
logger.error(f"Failed to create root path {e}")
app = Flask(__name__) app = Flask(__name__)
app.config['SERVER_NAME'] = env_provider.BACKEND_HOST #app.config['SERVER_NAME'] = env_provider.BACKEND_HOST
app.secret_key = env_provider.SESSION_SECRET_KEY app.secret_key = env_provider.SESSION_SECRET_KEY
CORS(app, resources={r"/api/*": {"origins": [ CORS(app, resources={r"/api/*": {"origins": [
env_provider.KC_HOST, env_provider.KC_HOST,

View File

@@ -1,5 +1,5 @@
#db/models/Path.py #db/models/Path.py
from sqlalchemy import Column, Text, LargeBinary, String, Integer, ForeignKey, UniqueConstraint from sqlalchemy import Column, String, Integer, ForeignKey, UniqueConstraint
from db.models import Base from db.models import Base
@@ -7,7 +7,7 @@ class Path(Base):
__tablename__ = "path" __tablename__ = "path"
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String(50), nullable=False) name = Column(String(50), nullable=False)
parent_id = Column(Integer, ForeignKey("path.id")) parent_id = Column(Integer, ForeignKey("path.id"), nullable=True)
__table_args__ = (UniqueConstraint("parent_id", "name", name="unique_parent_id_name"), ) __table_args__ = (UniqueConstraint("parent_id", "name", name="unique_parent_id_name"), )
def to_dict(self): def to_dict(self):
return { return {

View File

@@ -2,6 +2,6 @@
from db import get_db from db import get_db
def insert_log(log_entry): def insert_log(log_entry):
with get_db() as db: with get_db() as session:
db.add(log_entry) session.add(log_entry)
db.commit() session.commit()

View File

@@ -15,9 +15,9 @@ class DatabaseLogHandler(logging.Handler):
log_entry = Log(message=message, level=level, application=self.application, extra=extra) log_entry = Log(message=message, level=level, application=self.application, extra=extra)
print(message) print(message)
try: try:
with get_db() as db: with get_db() as session:
db.add(log_entry) session.add(log_entry)
db.commit() session.commit()
RequestContext.set_error_id(log_entry.id) RequestContext.set_error_id(log_entry.id)
except Exception: except Exception:
print(f"failed to log") print(f"failed to log")