api limiter & auto bp register

This commit is contained in:
h z
2024-12-03 11:28:40 +00:00
parent a93bd5d870
commit e929f67f4e
6 changed files with 62 additions and 24 deletions

View File

@@ -1,7 +1,11 @@
#api/__init__.py #api/__init__.py
import os
from functools import wraps from functools import wraps
from flask import jsonify, session from flask import jsonify, session, Blueprint
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
import importlib
def require_auth(roles=[]): def require_auth(roles=[]):
@@ -15,4 +19,23 @@ def require_auth(roles=[]):
return jsonify({"error": "Forbidden, permission denied"}), 403 return jsonify({"error": "Forbidden, permission denied"}), 403
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper return wrapper
return decorator return decorator
limiter = Limiter(
key_func=get_remote_address,
default_limits=["100 per minute"]
)
def register_blueprints(app):
current_dir = os.path.dirname(__file__)
for filename in os.listdir(current_dir):
if filename == "__init__.py" or not filename.endswith(".py"):
continue
module_name = filename[:-3]
module = importlib.import_module(f"api.{module_name}")
for attr in dir(module):
bp = getattr(module, attr)
if isinstance(bp, Blueprint):
app.register_blueprint(bp)

View File

@@ -4,6 +4,7 @@ from authlib.integrations.flask_client import OAuth
from contexts.RequestContext import RequestContext from contexts.RequestContext import RequestContext
import env_provider import env_provider
import logging import logging
from api import limiter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
auth_bp = Blueprint('auth', __name__, url_prefix='/api') auth_bp = Blueprint('auth', __name__, url_prefix='/api')
@@ -17,11 +18,13 @@ keycloak = oauth.register(
) )
@auth_bp.route('/login', methods=['GET']) @auth_bp.route('/login', methods=['GET'])
@limiter.limit("20 per minute")
def login(): def login():
redirect_uri = url_for("auth.authorize", _external=True) redirect_uri = url_for("auth.authorize", _external=True)
return keycloak.authorize_redirect(redirect_uri) return keycloak.authorize_redirect(redirect_uri)
@auth_bp.route('/authorize', methods=['GET']) @auth_bp.route('/authorize', methods=['GET'])
@limiter.limit("20 per minute")
def authorize(): def authorize():
try: try:
token = keycloak.authorize_access_token() token = keycloak.authorize_access_token()
@@ -41,6 +44,7 @@ def logout():
return redirect(logout_url) return redirect(logout_url)
@auth_bp.route("/user", methods=["GET"]) @auth_bp.route("/user", methods=["GET"])
@limiter.limit("80 per minute")
def user(): def user():
u = session.get('user') u = session.get('user')
if not u: if not u:

View File

@@ -6,11 +6,22 @@ from contexts.RequestContext import RequestContext
from db import get_db from db import get_db
from db.models.Markdown import Markdown from db.models.Markdown import Markdown
import logging import logging
from api import limiter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
markdown_bp = Blueprint('markdown', __name__, url_prefix='/api/markdown') markdown_bp = Blueprint('markdown', __name__, url_prefix='/api/markdown')
@markdown_bp.route('/', methods=['GET'])
@limiter.limit('5 per minute')
def get_markdowns():
with get_db() as db:
mds = db.query(Markdown).all()
return jsonify([md.to_dict() for md in mds])
@markdown_bp.route('/<int:markdown_id>', methods=['GET']) @markdown_bp.route('/<int:markdown_id>', methods=['GET'])
@limiter.limit('120 per minute')
def get_markdown(markdown_id): def get_markdown(markdown_id):
with get_db() as db: with get_db() as db:
markdown = db.query(Markdown).get(markdown_id) markdown = db.query(Markdown).get(markdown_id)
@@ -20,6 +31,7 @@ def get_markdown(markdown_id):
@markdown_bp.route('/', methods=['POST']) @markdown_bp.route('/', methods=['POST'])
@require_auth(roles=['admin', 'creator']) @require_auth(roles=['admin', 'creator'])
@limiter.limit('20 per minute')
def create_markdown(): def create_markdown():
data = request.json data = request.json
title = data.get('title') title = data.get('title')
@@ -41,6 +53,7 @@ def create_markdown():
@markdown_bp.route('/<int:markdown_id>', methods=['PUT']) @markdown_bp.route('/<int:markdown_id>', methods=['PUT'])
@require_auth(roles=['admin', 'creator']) @require_auth(roles=['admin', 'creator'])
@limiter.limit('20 per minute')
def update_markdown(markdown_id): def update_markdown(markdown_id):
with get_db() as db: with get_db() as db:
markdown = db.query(Markdown).get(markdown_id) markdown = db.query(Markdown).get(markdown_id)
@@ -55,6 +68,7 @@ def update_markdown(markdown_id):
@markdown_bp.route('/<int:markdown_id>', methods=['DELETE']) @markdown_bp.route('/<int:markdown_id>', methods=['DELETE'])
@require_auth(roles=['admin']) @require_auth(roles=['admin'])
@limiter.limit('20 per minute')
def delete_markdown(markdown_id): def delete_markdown(markdown_id):
with get_db() as db: with get_db() as db:
markdown = db.query(Markdown).get(markdown_id) markdown = db.query(Markdown).get(markdown_id)

View File

@@ -4,11 +4,12 @@ from flask import Blueprint, jsonify, request
from contexts.RequestContext import RequestContext from contexts.RequestContext import RequestContext
from db import get_db from db import get_db
from db.models.Resource import Resource from db.models.Resource import Resource
from api import require_auth from api import require_auth, limiter
import logging import logging
resource_bp = Blueprint('resource', __name__, url_prefix='/api/resource') resource_bp = Blueprint('resource', __name__, url_prefix='/api/resource')
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@resource_bp.route('/<identifier>', methods=['GET']) @resource_bp.route('/<identifier>', methods=['GET'])
@limiter.limit('10 per second')
def get_resource(identifier): def get_resource(identifier):
with get_db() as db: with get_db() as db:
resource = db.query(Resource).get(identifier) resource = db.query(Resource).get(identifier)
@@ -20,6 +21,7 @@ def get_resource(identifier):
@resource_bp.route('/', methods=['POST']) @resource_bp.route('/', methods=['POST'])
@require_auth(roles=["admin", "creator"]) @require_auth(roles=["admin", "creator"])
@limiter.limit('20 per minute')
def create_resource(): def create_resource():
data = request.get_json() data = request.get_json()
identifier = data.get('id') identifier = data.get('id')
@@ -41,6 +43,7 @@ def create_resource():
@resource_bp.route('/<identifier>', methods=['DELETE']) @resource_bp.route('/<identifier>', methods=['DELETE'])
@require_auth(roles=["admin"]) @require_auth(roles=["admin"])
@limiter.limit('20 per minute')
def delete_resource(identifier): def delete_resource(identifier):
with get_db() as db: with get_db() as db:
resource = db.query(Resource).get(identifier) resource = db.query(Resource).get(identifier)

18
app.py
View File

@@ -1,18 +1,14 @@
# app.py # app.py
from logging_handlers.DatabaseLogHandler import DatabaseLogHandler
from urllib.parse import urlparse from urllib.parse import urlparse
from api import limiter
from flask import Flask, request from flask import Flask, request
from flask_cors import CORS from flask_cors import CORS
import api
import env_provider import env_provider
import db import db
from api.auth import auth_bp
from api.log import logs_bp
from api.markdown import markdown_bp
import logging import logging
from api.resource import resource_bp
from logging_handlers.DatabaseLogHandler import DatabaseLogHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
db_handler = DatabaseLogHandler(application="backend") db_handler = DatabaseLogHandler(application="backend")
@@ -37,10 +33,10 @@ app = Flask(__name__)
app.secret_key = env_provider.SESSION_SECRET_KEY app.secret_key = env_provider.SESSION_SECRET_KEY
CORS(app, resources={r"/api/*": {"origins": is_allowed_origin}}) CORS(app, resources={r"/api/*": {"origins": is_allowed_origin}})
app.register_blueprint(markdown_bp) limiter.init_app(app)
app.register_blueprint(auth_bp)
app.register_blueprint(logs_bp) api.register_blueprints(app)
app.register_blueprint(resource_bp)
@app.before_request @app.before_request
def log_request(): def log_request():
if request.path.startswith("/api/log"): if request.path.startswith("/api/log"):

View File

@@ -1,12 +1,10 @@
import threading from flask import g
class RequestContext: class RequestContext:
_thread_local = threading.local() @staticmethod
def set_error_id(error_id):
g.error_id = error_id
@classmethod @staticmethod
def set_error_id(cls, error_id): def get_error_id():
cls._thread_local.error_id = error_id return getattr(g, "error_id", None)
@classmethod
def get_error_id(cls):
return getattr(cls._thread_local, "error_id", None)