diff --git a/README.md b/README.md index 630e763..eccf0d7 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ DEBUG=True PROJECT_NAME=YunoHost DOMAIN=http://localhost:8000 STATIC_DIR=assets +SECRET_CSRF_KEY=TO_CHANGE # Stripe keys STRIPE_PUBLISHABLE_KEY=pk_test_gOgGjacs9YfvDJY03BRZ576O diff --git a/assets/index.js b/assets/index.js index a3eb997..c97a962 100644 --- a/assets/index.js +++ b/assets/index.js @@ -27,6 +27,7 @@ submitBtn.addEventListener('click', function (evt) { 'Content-Type': 'application/json', }, body: JSON.stringify({ + user_csrf: window.config.csrf, quantity: quantity, currency: currency, frequency: frequency diff --git a/requirements.txt b/requirements.txt index eead772..2585fb6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ stripe==2.47.0 toml==0.9.6 urllib3==1.25.3 Werkzeug==1.0.1 +flask-simple-csrf diff --git a/server.py b/server.py index f8529e9..e464d56 100644 --- a/server.py +++ b/server.py @@ -11,8 +11,10 @@ import json import os from flask import Flask, render_template, jsonify, request, send_from_directory +from flask_simple_csrf import CSRF from dotenv import load_dotenv, find_dotenv + # Setup Stripe python client library. load_dotenv(find_dotenv()) @@ -23,7 +25,14 @@ static_dir = str(os.path.abspath(os.path.join( __file__, "..", os.getenv("STATIC_DIR")))) app = Flask(__name__, static_folder=static_dir, static_url_path="", template_folder=static_dir) +CSRF = CSRF(config=os.getenv('CSRF_CONFIG')) +app = CSRF.init_app(app) +@app.before_request +def before_request(): + if 'CSRF_TOKEN' not in session or 'USER_CSRF' not in session: + session['USER_CSRF'] = random_string(64) + session['CSRF_TOKEN'] = CSRF.create(session['USER_CSRF']) @app.route('/', methods=['GET']) def get_index(): @@ -35,6 +44,7 @@ def get_publishable_key(): return jsonify({ 'publicKey': os.getenv('STRIPE_PUBLISHABLE_KEY'), 'name': os.getenv('PROJECT_NAME'), + 'csrf': session['USER_CSRF'], }) @app.route('/create-checkout-session', methods=['POST']) @@ -42,7 +52,8 @@ def create_checkout_session(): data = json.loads(request.data) domain_url = os.getenv('DOMAIN') try: - if data['frequency'] not in ['RECURING', 'ONE_TIME'] or + if CSRF.verify(data['user_csrf'], session['CSRF_TOKEN']) is False or + data['frequency'] not in ['RECURING', 'ONE_TIME'] or data['currency'] not in ['EUR', 'USD'] or int(data['quantity']) <= 0: return jsonify(error="Bad value"), 400