Compare commits

...

1 Commits

Author SHA1 Message Date
48dd59f8e4 kc token public key/token issue, path root set to 1 2024-12-06 10:04:03 +00:00
9 changed files with 139 additions and 82 deletions

View File

@@ -1,63 +1,101 @@
#api/__init__.py
import base64
import os
from functools import wraps
from flask import jsonify, session, Blueprint, request, g
from cryptography import x509
from cryptography.hazmat.primitives import serialization
from flask import jsonify, Blueprint, request
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from jwt import decode, ExpiredSignatureError, InvalidTokenError
from jwt import decode, ExpiredSignatureError, InvalidTokenError, get_unverified_header
import importlib
import requests
from threading import Lock
import env_provider
_public_key_cache = None
_public_key_cache = {}
_lock = Lock()
def keycloak_public_key():
global _public_key_cache
if _public_key_cache:
return _public_key_cache
with _lock:
if _public_key_cache:
return _public_key_cache
url = f"{env_provider.KC_HOST}/realms/{env_provider.KC_REALM}/protocol/openid-connect/certs"
response = requests.get(url)
jwks = response.json()
public_key = jwks["keys"][0]["x5c"][0]
_public_key_cache = f"-----BEGIN CERTIFICATE-----\n{public_key}\n-----END CERTIFICATE-----"
return _public_key_cache
def x5c_to_public_key(x5c):
cert_der = base64.b64decode(x5c)
cert = x509.load_der_x509_certificate(cert_der)
public_key = cert.public_key()
pem = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
return pem
def get_jwks():
url = f"{env_provider.KC_HOST}/realms/{env_provider.KC_REALM}/protocol/openid-connect/certs"
response = requests.get(url)
jwks = response.json()
return jwks
def get_public_key_for_kid(kid):
global _public_key_cache
if kid in _public_key_cache:
return _public_key_cache[kid]
jwks = get_jwks()
res = []
for key_data in jwks["keys"]:
if key_data["kid"] == kid and key_data["use"] == "sig" and key_data["alg"] == "RS256" and key_data["kty"] == "RSA":
x5c = key_data["x5c"][0]
pem_public_key = x5c_to_public_key(x5c)
_public_key_cache[kid] = pem_public_key
res.append(pem_public_key)
if len(res) > 0:
print(len(res))
return res[0]
return None
def verify_token(token):
try:
header = get_unverified_header(token)
kid = header.get("kid")
if not kid:
return None
public_key = get_public_key_for_kid(kid)
if not public_key:
return None
decoded = decode(
token,
keycloak_public_key(),
public_key,
algorithms=["RS256"],
audience=env_provider.KC_CLIENT_ID
)
return decoded
except ExpiredSignatureError:
except ExpiredSignatureError as e:
print(e)
return None
except InvalidTokenError:
except InvalidTokenError as e:
print(e)
return None
def require_auth(roles=[]):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if request.method == "OPTIONS":
return '', 200
auth_header = request.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Bearer'):
return jsonify({"error": "Unauthorized"}), 401
token = auth_header.split(" ")[1]
decoded = verify_token(token)
if not decoded:
return jsonify({"error": "Invalid or expired token"}), 401
user_roles = decoded.get("roles", [])
if roles and not set(roles).issubset(set(user_roles)):
user_roles = decoded.get("resource_access", {}).get(env_provider.KC_CLIENT_ID, {}).get("roles", [])
if roles and not (set(roles) & set(user_roles)):
print("auth failed")
return jsonify({"error": "Forbidden, permission denied"}), 403
g.user = decoded
print("auth success")
return func(*args, **kwargs)
return wrapper
return decorator

View File

@@ -13,8 +13,8 @@ def get_logs():
page = int(request.args.get('page', 1))
per_page = int(request.args.get('per_page', 10))
with get_db() as db:
query = db.query(Log)
with get_db() as session:
query = session.query(Log)
if level:
query = query.filter(Log.level == level)
if application:

View File

@@ -14,15 +14,15 @@ markdown_bp = Blueprint('markdown', __name__, url_prefix='/api/markdown')
@markdown_bp.route('/', methods=['GET'])
@limiter.limit('5 per minute')
def get_markdowns():
with get_db() as db:
mds = db.query(Markdown).all()
with get_db() as session:
mds = session.query(Markdown).all()
return jsonify([md.to_dict() for md in mds]), 200
@markdown_bp.route('/by_path/<int:path_id>', methods=['GET'])
@limiter.limit('5 per minute')
def get_markdowns_by_path(path_id):
with get_db() as db:
markdowns = db.query(Markdown).filter(Markdown.path_id == path_id).all()
with get_db() as session:
markdowns = session.query(Markdown).filter(Markdown.path_id == path_id).all()
return jsonify([md.to_dict() for md in markdowns]), 200
@@ -30,8 +30,8 @@ def get_markdowns_by_path(path_id):
@markdown_bp.route('/<int:markdown_id>', methods=['GET'])
@limiter.limit('120 per minute')
def get_markdown(markdown_id):
with get_db() as db:
markdown = db.query(Markdown).get(markdown_id)
with get_db() as session:
markdown = session.query(Markdown).get(markdown_id)
if markdown is None:
return jsonify({"error": "file not found"}), 404
return jsonify(markdown.to_dict())
@@ -47,42 +47,42 @@ def create_markdown():
if not title or not content:
return jsonify({"error": "missing required fields"}), 400
new_markdown = Markdown(title=title, content=content, path_id=path_id)
with get_db() as db:
with get_db() as session:
try:
db.add(new_markdown)
db.commit()
session.add(new_markdown)
session.commit()
return jsonify(new_markdown.to_dict()), 201
except Exception as e:
logger.error(f"failed to create markdown: {e}")
errno = RequestContext.get_error_id()
db.rollback()
session.rollback()
return jsonify({"error": f"create failed - {errno}"}), 500
@markdown_bp.route('/<int:markdown_id>', methods=['PUT'])
@require_auth(roles=['admin', 'creator'])
@limiter.limit('20 per minute')
def update_markdown(markdown_id):
with get_db() as db:
markdown = db.query(Markdown).get(markdown_id)
with get_db() as session:
markdown = session.query(Markdown).get(markdown_id)
if markdown is None:
return jsonify({"error": "file not found"}), 404
data = request.json
markdown.title = data.get('title')
markdown.content = data.get('content')
markdown.path_id = data.get('path_id')
db.commit()
session.commit()
return jsonify(markdown.to_dict()), 200
@markdown_bp.route('/<int:markdown_id>', methods=['DELETE'])
@require_auth(roles=['admin'])
@limiter.limit('20 per minute')
def delete_markdown(markdown_id):
with get_db() as db:
markdown = db.query(Markdown).get(markdown_id)
with get_db() as session:
markdown = session.query(Markdown).get(markdown_id)
if markdown is None:
logger.error(f"failed to delete markdown: {markdown_id}")
errno = RequestContext.get_error_id()
return jsonify({"error": f"file not found - {errno}"}), 404
db.delete(markdown)
db.commit()
session.delete(markdown)
session.commit()
return jsonify({"message": "deleted"}), 200

View File

@@ -12,15 +12,15 @@ path_bp = Blueprint('path', __name__, url_prefix='/api/path')
@path_bp.route('/', methods=['GET'])
@limiter.limit('5 per minute')
def get_root_paths():
with get_db() as db:
paths = db.query(Path).filter(Path.parent_id == 0)
with get_db() as session:
paths = session.query(Path).filter(Path.parent_id == 1)
return jsonify([pth.to_dict() for pth in paths]), 200
@path_bp.route('/<int:path_id>', methods=['GET'])
@limiter.limit('5 per minute')
def get_path(path_id):
with get_db() as db:
path = db.query(Path).get(path_id)
with get_db() as session:
path = session.query(Path).get(path_id)
if path is None:
return jsonify({"error": "file not found"}), 404
return jsonify(path.to_dict()), 200
@@ -28,8 +28,8 @@ def get_path(path_id):
@path_bp.route('/parent/<int:parent_id>', methods=['GET'])
@limiter.limit('5 per minute')
def get_path_by_parent(parent_id):
with get_db() as db:
paths = db.query(Path).filter(Path.parent_id == parent_id).all()
with get_db() as session:
paths = session.query(Path).filter(Path.parent_id == parent_id).all()
return jsonify([pth.to_dict() for pth in paths]), 200
@path_bp.route('/', methods=['POST'])
@@ -39,15 +39,14 @@ def create_path():
data = request.json
if not data or 'name' not in data or 'parent_id' not in data:
return jsonify({"error": "bad request"}), 400
with get_db() as db:
if data['parent_id'] != 0 and not db.query(Path).get(data['parent_id']):
with get_db() as session:
if data['parent_id'] != 0 and not session.query(Path).get(data['parent_id']):
return jsonify({"error": "path not found"}), 404
if db.query(Path).filter_by(name=data['name'], parent_id=data['parent_id']).first():
if session.query(Path).filter_by(name=data['name'], parent_id=data['parent_id']).first():
return jsonify({"error": "Path already exists under the parent"}), 409
new_path = Path(name=data['name'], parent_id=data['parent_id'])
db.add(new_path)
db.commit()
session.add(new_path)
session.commit()
return jsonify(new_path.to_dict()), 201
@path_bp.route('/<int:path_id>', methods=['PUT'])
@@ -57,30 +56,30 @@ def update_path(path_id):
data = request.json
if not data or 'name' not in data or 'parent_id' not in data:
return jsonify({"error": "bad request"}), 400
with get_db() as db:
path = db.query(Path).get(path_id)
with get_db() as session:
path = session.query(Path).get(path_id)
if path is None:
return jsonify({"error": "path not found"}), 404
if db.query(Path).filter_by(name=data['name'], parent_id=data['parent_id']).first():
if session.query(Path).filter_by(name=data['name'], parent_id=data['parent_id']).first():
return jsonify({"error": "Path already exists under the parent"}), 409
path.name = data['name']
path.parent_id = data['parent_id']
db.commit()
session.commit()
return jsonify(path.to_dict()), 200
@path_bp.route('/<int:path_id>', methods=['DELETE'])
@limiter.limit('60 per minute')
@require_auth(roles=['admin'])
def delete_path(path_id):
with get_db() as db:
path = db.query(Path).get(path_id)
with get_db() as session:
path = session.query(Path).get(path_id)
if not path:
return jsonify({"error": "path not found"}), 404
if db.query(Path).filter_by(parent_id=path_id).first():
if session.query(Path).filter_by(parent_id=path_id).first():
return jsonify({"error": "can not delete non empty path"}), 409
if db.query(Markdown).filter_by(path_id=path_id).first():
if session.query(Markdown).filter_by(path_id=path_id).first():
return jsonify({"error": "can not delete non empty path"}), 409
db.delete(path)
db.commit()
session.delete(path)
session.commit()
return jsonify({"message": "path deleted"}), 200

View File

@@ -11,8 +11,8 @@ logger = logging.getLogger(__name__)
@resource_bp.route('/<identifier>', methods=['GET'])
@limiter.limit('10 per second')
def get_resource(identifier):
with get_db() as db:
resource = db.query(Resource).get(identifier)
with get_db() as session:
resource = session.query(Resource).get(identifier)
if resource is None:
logger.error(f"resource not found: {identifier}")
errno = RequestContext.get_error_id()
@@ -30,13 +30,13 @@ def create_resource():
if not identifier or not content or not data_type:
return jsonify({"error": "missing required fields"}), 400
resource_entry = Resource(id=identifier, content=content, data_type=data_type)
with get_db() as db:
with get_db() as session:
try:
db.add(resource_entry)
db.commit()
session.add(resource_entry)
session.commit()
return jsonify(resource_entry.to_dict()), 201
except Exception as e:
db.rollback()
session.rollback()
logger.error(f"Failed to create resource: {e}")
errno = RequestContext.get_error_id()
return jsonify({"error": f"failed to create resource - {errno}"}), 500
@@ -45,12 +45,12 @@ def create_resource():
@require_auth(roles=["admin"])
@limiter.limit('20 per minute')
def delete_resource(identifier):
with get_db() as db:
resource = db.query(Resource).get(identifier)
with get_db() as session:
resource = session.query(Resource).get(identifier)
if not resource:
logger.error(f"resource not found: {identifier}")
errno = RequestContext.get_error_id()
return jsonify({"error": f"Resource not found - {errno}"}), 404
db.delete(resource)
db.commit()
session.delete(resource)
session.commit()
return jsonify({"message": "Resource deleted"}), 200

22
app.py
View File

@@ -1,6 +1,9 @@
# app.py
from pprint import pprint
from sqlalchemy import text
from db.models.Path import Path
from logging_handlers.DatabaseLogHandler import DatabaseLogHandler
from urllib.parse import urlparse
from api import limiter
@@ -24,8 +27,25 @@ try:
db.create_all()
except Exception as e:
print(f"db not ready {e}")
try:
with db.get_db() as session:
root_path = session.query(Path).filter(Path.id == 1).first()
if not root_path:
session.execute(text("SET FOREIGN_KEY_CHECKS=0;"))
#session.execute(text("ALTER TABLE path AUTO_INCREMENT = 0;"))
session.execute(text("ALTER TABLE path MODIFY COLUMN id INT;"))
root_path = Path(id=1, name="")
session.add(root_path)
session.commit()
session.execute(text("ALTER TABLE path MODIFY COLUMN id INT AUTO_INCREMENT;"))
session.execute(text("SET FOREIGN_KEY_CHECKS=1;"))
logger.info("Root path created")
except Exception as e:
logger.error(f"Failed to create root path {e}")
app = Flask(__name__)
app.config['SERVER_NAME'] = env_provider.BACKEND_HOST
#app.config['SERVER_NAME'] = env_provider.BACKEND_HOST
app.secret_key = env_provider.SESSION_SECRET_KEY
CORS(app, resources={r"/api/*": {"origins": [
env_provider.KC_HOST,

View File

@@ -1,5 +1,5 @@
#db/models/Path.py
from sqlalchemy import Column, Text, LargeBinary, String, Integer, ForeignKey, UniqueConstraint
from sqlalchemy import Column, String, Integer, ForeignKey, UniqueConstraint
from db.models import Base
@@ -7,7 +7,7 @@ class Path(Base):
__tablename__ = "path"
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String(50), nullable=False)
parent_id = Column(Integer, ForeignKey("path.id"))
parent_id = Column(Integer, ForeignKey("path.id"), nullable=True)
__table_args__ = (UniqueConstraint("parent_id", "name", name="unique_parent_id_name"), )
def to_dict(self):
return {

View File

@@ -2,6 +2,6 @@
from db import get_db
def insert_log(log_entry):
with get_db() as db:
db.add(log_entry)
db.commit()
with get_db() as session:
session.add(log_entry)
session.commit()

View File

@@ -15,9 +15,9 @@ class DatabaseLogHandler(logging.Handler):
log_entry = Log(message=message, level=level, application=self.application, extra=extra)
print(message)
try:
with get_db() as db:
db.add(log_entry)
db.commit()
with get_db() as session:
session.add(log_entry)
session.commit()
RequestContext.set_error_id(log_entry.id)
except Exception:
print(f"failed to log")