diff --git a/api/__init__.py b/api/__init__.py index 9f759ed..e3b800e 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -1,7 +1,11 @@ #api/__init__.py - +import os from functools import wraps -from flask import jsonify, session +from flask import jsonify, session, Blueprint +from flask_limiter import Limiter +from flask_limiter.util import get_remote_address + +import importlib def require_auth(roles=[]): @@ -15,4 +19,23 @@ def require_auth(roles=[]): return jsonify({"error": "Forbidden, permission denied"}), 403 return func(*args, **kwargs) return wrapper - return decorator \ No newline at end of file + return decorator + + +limiter = Limiter( + key_func=get_remote_address, + default_limits=["100 per minute"] +) + + +def register_blueprints(app): + current_dir = os.path.dirname(__file__) + for filename in os.listdir(current_dir): + if filename == "__init__.py" or not filename.endswith(".py"): + continue + module_name = filename[:-3] + module = importlib.import_module(f"api.{module_name}") + for attr in dir(module): + bp = getattr(module, attr) + if isinstance(bp, Blueprint): + app.register_blueprint(bp) \ No newline at end of file diff --git a/api/auth.py b/api/auth.py index eee4307..20fa6d1 100644 --- a/api/auth.py +++ b/api/auth.py @@ -4,6 +4,7 @@ from authlib.integrations.flask_client import OAuth from contexts.RequestContext import RequestContext import env_provider import logging +from api import limiter logger = logging.getLogger(__name__) auth_bp = Blueprint('auth', __name__, url_prefix='/api') @@ -17,11 +18,13 @@ keycloak = oauth.register( ) @auth_bp.route('/login', methods=['GET']) +@limiter.limit("20 per minute") def login(): redirect_uri = url_for("auth.authorize", _external=True) return keycloak.authorize_redirect(redirect_uri) @auth_bp.route('/authorize', methods=['GET']) +@limiter.limit("20 per minute") def authorize(): try: token = keycloak.authorize_access_token() @@ -41,6 +44,7 @@ def logout(): return redirect(logout_url) @auth_bp.route("/user", methods=["GET"]) +@limiter.limit("80 per minute") def user(): u = session.get('user') if not u: diff --git a/api/markdown.py b/api/markdown.py index e74c907..053e848 100644 --- a/api/markdown.py +++ b/api/markdown.py @@ -6,11 +6,22 @@ from contexts.RequestContext import RequestContext from db import get_db from db.models.Markdown import Markdown import logging +from api import limiter logger = logging.getLogger(__name__) markdown_bp = Blueprint('markdown', __name__, url_prefix='/api/markdown') +@markdown_bp.route('/', methods=['GET']) +@limiter.limit('5 per minute') +def get_markdowns(): + with get_db() as db: + mds = db.query(Markdown).all() + return jsonify([md.to_dict() for md in mds]) + + + @markdown_bp.route('/', methods=['GET']) +@limiter.limit('120 per minute') def get_markdown(markdown_id): with get_db() as db: markdown = db.query(Markdown).get(markdown_id) @@ -20,6 +31,7 @@ def get_markdown(markdown_id): @markdown_bp.route('/', methods=['POST']) @require_auth(roles=['admin', 'creator']) +@limiter.limit('20 per minute') def create_markdown(): data = request.json title = data.get('title') @@ -41,6 +53,7 @@ def create_markdown(): @markdown_bp.route('/', methods=['PUT']) @require_auth(roles=['admin', 'creator']) +@limiter.limit('20 per minute') def update_markdown(markdown_id): with get_db() as db: markdown = db.query(Markdown).get(markdown_id) @@ -55,6 +68,7 @@ def update_markdown(markdown_id): @markdown_bp.route('/', methods=['DELETE']) @require_auth(roles=['admin']) +@limiter.limit('20 per minute') def delete_markdown(markdown_id): with get_db() as db: markdown = db.query(Markdown).get(markdown_id) diff --git a/api/resource.py b/api/resource.py index 3f0a9e5..44f3e9c 100644 --- a/api/resource.py +++ b/api/resource.py @@ -4,11 +4,12 @@ from flask import Blueprint, jsonify, request from contexts.RequestContext import RequestContext from db import get_db from db.models.Resource import Resource -from api import require_auth +from api import require_auth, limiter import logging resource_bp = Blueprint('resource', __name__, url_prefix='/api/resource') logger = logging.getLogger(__name__) @resource_bp.route('/', methods=['GET']) +@limiter.limit('10 per second') def get_resource(identifier): with get_db() as db: resource = db.query(Resource).get(identifier) @@ -20,6 +21,7 @@ def get_resource(identifier): @resource_bp.route('/', methods=['POST']) @require_auth(roles=["admin", "creator"]) +@limiter.limit('20 per minute') def create_resource(): data = request.get_json() identifier = data.get('id') @@ -41,6 +43,7 @@ def create_resource(): @resource_bp.route('/', methods=['DELETE']) @require_auth(roles=["admin"]) +@limiter.limit('20 per minute') def delete_resource(identifier): with get_db() as db: resource = db.query(Resource).get(identifier) diff --git a/app.py b/app.py index c49b384..f008f1c 100644 --- a/app.py +++ b/app.py @@ -1,18 +1,14 @@ # app.py +from logging_handlers.DatabaseLogHandler import DatabaseLogHandler from urllib.parse import urlparse - +from api import limiter from flask import Flask, request from flask_cors import CORS +import api import env_provider import db -from api.auth import auth_bp -from api.log import logs_bp -from api.markdown import markdown_bp import logging -from api.resource import resource_bp -from logging_handlers.DatabaseLogHandler import DatabaseLogHandler - logger = logging.getLogger(__name__) db_handler = DatabaseLogHandler(application="backend") @@ -37,10 +33,10 @@ app = Flask(__name__) app.secret_key = env_provider.SESSION_SECRET_KEY CORS(app, resources={r"/api/*": {"origins": is_allowed_origin}}) -app.register_blueprint(markdown_bp) -app.register_blueprint(auth_bp) -app.register_blueprint(logs_bp) -app.register_blueprint(resource_bp) +limiter.init_app(app) + +api.register_blueprints(app) + @app.before_request def log_request(): if request.path.startswith("/api/log"): diff --git a/contexts/RequestContext.py b/contexts/RequestContext.py index 7b6e16b..0a01bd5 100644 --- a/contexts/RequestContext.py +++ b/contexts/RequestContext.py @@ -1,12 +1,10 @@ -import threading +from flask import g class RequestContext: - _thread_local = threading.local() + @staticmethod + def set_error_id(error_id): + g.error_id = error_id - @classmethod - def set_error_id(cls, error_id): - cls._thread_local.error_id = error_id - - @classmethod - def get_error_id(cls): - return getattr(cls._thread_local, "error_id", None) + @staticmethod + def get_error_id(): + return getattr(g, "error_id", None)