From 84c9a74d3380f59cdc9fda6aa5bf5fac9d619a0c Mon Sep 17 00:00:00 2001 From: Gabriel Corona Date: Sun, 2 Dec 2018 02:32:59 +0100 Subject: [PATCH] Protect against CSRF (#171) --- moulinette/interfaces/api.py | 31 +++++++++++ setup.py | 3 +- tests/test_api.py | 100 +++++++++++++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 1 deletion(-) create mode 100644 tests/test_api.py diff --git a/moulinette/interfaces/api.py b/moulinette/interfaces/api.py index 4ce66294..abe3c90b 100644 --- a/moulinette/interfaces/api.py +++ b/moulinette/interfaces/api.py @@ -11,6 +11,7 @@ from gevent.queue import Queue from geventwebsocket import WebSocketError from bottle import run, request, response, Bottle, HTTPResponse +from bottle import get, post, install, abort, delete, put from moulinette import msignals, m18n, DATA_DIR from moulinette.core import MoulinetteError, clean_session @@ -26,6 +27,35 @@ logger = log.getLogger('moulinette.interface.api') # API helpers ---------------------------------------------------------- +CSRF_TYPES = set(["text/plain", + "application/x-www-form-urlencoded", + "multipart/form-data"]) + + +def is_csrf(): + """Checks is this is a CSRF request.""" + + if request.method != "POST": + return False + if request.content_type is None: + return True + content_type = request.content_type.lower().split(';')[0] + if content_type not in CSRF_TYPES: + return False + + return request.headers.get("X-Requested-With") is None + + +# Protection against CSRF +def filter_csrf(callback): + def wrapper(*args, **kwargs): + if is_csrf(): + abort(403, "CSRF protection") + else: + return callback(*args, **kwargs) + return wrapper + + class LogQueues(dict): """Map of session id to queue.""" pass @@ -722,6 +752,7 @@ class Interface(BaseInterface): return callback # Install plugins + app.install(filter_csrf) app.install(apiheader) app.install(api18n) app.install(_ActionsMapPlugin(actionsmap, use_websocket, log_queues)) diff --git a/setup.py b/setup.py index b9dddbaa..ea4ded50 100755 --- a/setup.py +++ b/setup.py @@ -30,5 +30,6 @@ setup(name='Moulinette', 'moulinette.interfaces', 'moulinette.utils', ], - data_files=[(LOCALES_DIR, locale_files)] + data_files=[(LOCALES_DIR, locale_files)], + tests_require=["pytest", "webtest"], ) diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 00000000..955fa577 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- + +from webtest import TestApp as WebTestApp +from bottle import Bottle +from moulinette.interfaces.api import filter_csrf + + +URLENCODED = 'application/x-www-form-urlencoded' +FORMDATA = 'multipart/form-data' +TEXT = 'text/plain' + +TYPES = [URLENCODED, FORMDATA, TEXT] +SAFE_METHODS = ["HEAD", "GET", "PUT", "DELETE"] + + +app = Bottle(autojson=True) +app.install(filter_csrf) + + +@app.get('/') +def get_hello(): + return "Hello World!\n" + + +@app.post('/') +def post_hello(): + return "OK\n" + + +@app.put('/') +def put_hello(): + return "OK\n" + + +@app.delete('/') +def delete_hello(): + return "OK\n" + + +webtest = WebTestApp(app) + + +def test_get(): + r = webtest.get("/") + assert r.status_code == 200 + + +def test_csrf_post(): + r = webtest.post("/", "test", expect_errors=True) + assert r.status_code == 403 + + +def test_post_json(): + r = webtest.post("/", "test", + headers=[("Content-Type", "application/json")]) + assert r.status_code == 200 + + +def test_csrf_post_text(): + r = webtest.post("/", "test", + headers=[("Content-Type", "text/plain")], + expect_errors=True) + assert r.status_code == 403 + + +def test_csrf_post_urlencoded(): + r = webtest.post("/", "test", + headers=[("Content-Type", + "application/x-www-form-urlencoded")], + expect_errors=True) + assert r.status_code == 403 + + +def test_csrf_post_form(): + r = webtest.post("/", "test", + headers=[("Content-Type", "multipart/form-data")], + expect_errors=True) + assert r.status_code == 403 + + +def test_ok_post_text(): + r = webtest.post("/", "test", + headers=[("Content-Type", "text/plain"), + ("X-Requested-With", "XMLHttpRequest")]) + assert r.status_code == 200 + + +def test_ok_post_urlencoded(): + r = webtest.post("/", "test", + headers=[("Content-Type", + "application/x-www-form-urlencoded"), + ("X-Requested-With", "XMLHttpRequest")]) + assert r.status_code == 200 + + +def test_ok_post_form(): + r = webtest.post("/", "test", + headers=[("Content-Type", "multipart/form-data"), + ("X-Requested-With", "XMLHttpRequest")]) + assert r.status_code == 200