From eaa2350b726ea035192a9f5779fd29229fd3cb59 Mon Sep 17 00:00:00 2001 From: hzhang Date: Fri, 6 Dec 2024 16:12:44 +0000 Subject: [PATCH] add: api for rate control --- api/__init__.py | 19 ++++++++++++++++++- api/config.py | 28 ++++++++++++++++++++++++++++ api/markdown.py | 13 +++++++------ api/path.py | 14 ++++++++------ api/resource.py | 11 ++++++----- app.py | 3 +++ 6 files changed, 70 insertions(+), 18 deletions(-) create mode 100644 api/config.py diff --git a/api/__init__.py b/api/__init__.py index 93ebc91..d3aa3c9 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -100,6 +100,21 @@ def require_auth(roles=[]): return wrapper return decorator +rate_limits = {} +default_rate_limit = "60 per minute" +def init_rate_limits(app): + global rate_limits + rate_limits = { + f"{rule.rule} : {method}": default_rate_limit + for rule in app.url_map.iter_rules() + for method in rule.methods + } + + +def get_rate_limit(): + key = f"{request.path} : {request.method}" + return rate_limits.get(key, default_rate_limit) + limiter = Limiter( key_func=get_remote_address, @@ -117,4 +132,6 @@ def register_blueprints(app): for attr in dir(module): bp = getattr(module, attr) if isinstance(bp, Blueprint): - app.register_blueprint(bp) \ No newline at end of file + app.register_blueprint(bp) + + diff --git a/api/config.py b/api/config.py new file mode 100644 index 0000000..7b57f59 --- /dev/null +++ b/api/config.py @@ -0,0 +1,28 @@ +from flask import Blueprint, jsonify, request + +from api import require_auth, rate_limits +import re +config_bp = Blueprint('config', __name__, url_prefix='/api/config') + +RATE_LIMIT_REGEX = re.compile(r'^\d+\s?(per\s|/)\s?(second|minute|hour|day)$') + +def is_valid_rate_limit(limit): + return bool(RATE_LIMIT_REGEX.match(limit)) +@config_bp.route('/limits', methods=['GET']) +@require_auth(roles=['admin']) +def limits(): + return jsonify(rate_limits), 200 + +@config_bp.route('/limits', methods=['PUT']) +@require_auth(roles=['admin']) +def update_limits(): + data = request.json + if not data or 'endpoint' not in data or 'method' not in data or 'new_limit' not in data: + return jsonify({'error': 'Bad request'}), 400 + key = f"{data['endpoint']} : {data['method']}" + if key not in rate_limits: + return jsonify({'error': 'endpoint not fount'}), 404 + if is_valid_rate_limit(data['new_limit']): + rate_limits[key] = data['new_limit'] + return jsonify({"message": "updated"}), 200 + return jsonify({'error': 'Invalid value'}), 400 diff --git a/api/markdown.py b/api/markdown.py index 0697a5e..14d7091 100644 --- a/api/markdown.py +++ b/api/markdown.py @@ -1,6 +1,7 @@ #api/markdown.py from flask import Blueprint, request, jsonify +import api from api import require_auth from contexts.RequestContext import RequestContext from db import get_db @@ -12,14 +13,14 @@ logger = logging.getLogger(__name__) markdown_bp = Blueprint('markdown', __name__, url_prefix='/api/markdown') @markdown_bp.route('/', methods=['GET']) -@limiter.limit('5 per minute') +@limiter.limit(api.get_rate_limit) def get_markdowns(): with get_db() as session: mds = session.query(Markdown).all() return jsonify([md.to_dict() for md in mds]), 200 @markdown_bp.route('/by_path/', methods=['GET']) -@limiter.limit('5 per minute') +@limiter.limit(api.get_rate_limit) def get_markdowns_by_path(path_id): with get_db() as session: markdowns = session.query(Markdown).filter(Markdown.path_id == path_id).all() @@ -28,7 +29,7 @@ def get_markdowns_by_path(path_id): @markdown_bp.route('/', methods=['GET']) -@limiter.limit('120 per minute') +@limiter.limit(api.get_rate_limit) def get_markdown(markdown_id): with get_db() as session: markdown = session.query(Markdown).get(markdown_id) @@ -38,7 +39,7 @@ def get_markdown(markdown_id): @markdown_bp.route('/', methods=['POST']) @require_auth(roles=['admin', 'creator']) -@limiter.limit('20 per minute') +@limiter.limit(api.get_rate_limit) def create_markdown(): data = request.json title = data.get('title') @@ -60,7 +61,7 @@ def create_markdown(): @markdown_bp.route('/', methods=['PUT']) @require_auth(roles=['admin', 'creator']) -@limiter.limit('20 per minute') +@limiter.limit(api.get_rate_limit) def update_markdown(markdown_id): with get_db() as session: markdown = session.query(Markdown).get(markdown_id) @@ -75,7 +76,7 @@ def update_markdown(markdown_id): @markdown_bp.route('/', methods=['DELETE']) @require_auth(roles=['admin']) -@limiter.limit('20 per minute') +@limiter.limit(api.get_rate_limit) def delete_markdown(markdown_id): with get_db() as session: markdown = session.query(Markdown).get(markdown_id) diff --git a/api/path.py b/api/path.py index acf82c3..35f8e2e 100644 --- a/api/path.py +++ b/api/path.py @@ -1,4 +1,6 @@ from flask import Blueprint, request, jsonify + +import api from api import require_auth from db import get_db from db.models.Markdown import Markdown @@ -10,14 +12,14 @@ logger = logging.getLogger(__name__) path_bp = Blueprint('path', __name__, url_prefix='/api/path') @path_bp.route('/', methods=['GET']) -@limiter.limit('5 per minute') +@limiter.limit(api.get_rate_limit) def get_root_paths(): with get_db() as session: paths = session.query(Path).filter(Path.parent_id == 1) return jsonify([pth.to_dict() for pth in paths]), 200 @path_bp.route('/', methods=['GET']) -@limiter.limit('5 per minute') +@limiter.limit(api.get_rate_limit) def get_path(path_id): with get_db() as session: path = session.query(Path).get(path_id) @@ -26,14 +28,14 @@ def get_path(path_id): return jsonify(path.to_dict()), 200 @path_bp.route('/parent/', methods=['GET']) -@limiter.limit('5 per minute') +@limiter.limit(api.get_rate_limit) def get_path_by_parent(parent_id): with get_db() as session: paths = session.query(Path).filter(Path.parent_id == parent_id).all() return jsonify([pth.to_dict() for pth in paths]), 200 @path_bp.route('/', methods=['POST']) -@limiter.limit('60 per minute') +@limiter.limit(api.get_rate_limit) @require_auth(roles=['admin', 'creator']) def create_path(): data = request.json @@ -50,7 +52,7 @@ def create_path(): return jsonify(new_path.to_dict()), 201 @path_bp.route('/', methods=['PUT']) -@limiter.limit('30 per minute') +@limiter.limit(api.get_rate_limit) @require_auth(roles=['admin']) def update_path(path_id): data = request.json @@ -68,7 +70,7 @@ def update_path(path_id): return jsonify(path.to_dict()), 200 @path_bp.route('/', methods=['DELETE']) -@limiter.limit('60 per minute') +@limiter.limit(api.get_rate_limit) @require_auth(roles=['admin']) def delete_path(path_id): with get_db() as session: diff --git a/api/resource.py b/api/resource.py index c9b926d..b161a2c 100644 --- a/api/resource.py +++ b/api/resource.py @@ -1,6 +1,6 @@ #api/resource.py +import api from flask import Blueprint, jsonify, request - from contexts.RequestContext import RequestContext from db import get_db from db.models.Resource import Resource @@ -9,7 +9,7 @@ import logging resource_bp = Blueprint('resource', __name__, url_prefix='/api/resource') logger = logging.getLogger(__name__) @resource_bp.route('/', methods=['GET']) -@limiter.limit('10 per second') +@limiter.limit(api.get_rate_limit) def get_resource(identifier): with get_db() as session: resource = session.query(Resource).get(identifier) @@ -21,7 +21,7 @@ def get_resource(identifier): @resource_bp.route('/', methods=['POST']) @require_auth(roles=["admin", "creator"]) -@limiter.limit('20 per minute') +@limiter.limit(api.get_rate_limit) def create_resource(): data = request.get_json() identifier = data.get('id') @@ -43,7 +43,7 @@ def create_resource(): @resource_bp.route('/', methods=['DELETE']) @require_auth(roles=["admin"]) -@limiter.limit('20 per minute') +@limiter.limit(api.get_rate_limit) def delete_resource(identifier): with get_db() as session: resource = session.query(Resource).get(identifier) @@ -53,4 +53,5 @@ def delete_resource(identifier): return jsonify({"error": f"Resource not found - {errno}"}), 404 session.delete(resource) session.commit() - return jsonify({"message": "Resource deleted"}), 200 \ No newline at end of file + return jsonify({"message": "Resource deleted"}), 200 + \ No newline at end of file diff --git a/app.py b/app.py index ac13b72..8b77a8e 100644 --- a/app.py +++ b/app.py @@ -44,7 +44,10 @@ def log_request(): logger.info(f"Request received: {request.method} {request.path} from {request.remote_addr}") + + if __name__ == '__main__': + api.init_rate_limits(app) #logger.info("Starting app") pprint(env_provider.summerize()) app.run(host='0.0.0.0', port=5000)