diff --git a/budget/forms.py b/budget/forms.py index c2455da..a097abe 100644 --- a/budget/forms.py +++ b/budget/forms.py @@ -174,3 +174,16 @@ class CreateArchiveForm(Form): name = TextField(_("Name for this archive (optional)"), validators=[]) start_date = DateField(_("Start date"), validators=[Required()]) end_date = DateField(_("End date"), validators=[Required()], default=datetime.now) + + +class ExportForm(Form): + export_type = SelectField(_("What do you want to download ?"), + validators=[Required()], + coerce=str, + choices=[("bills", _("bills")), ("transactions", _("transactions"))] + ) + export_format = SelectField(_("Export file format"), + validators=[Required()], + coerce=str, + choices=[("csv", "csv"), ("json", "json")] + ) diff --git a/budget/models.py b/budget/models.py index f894207..8f570fb 100644 --- a/budget/models.py +++ b/budget/models.py @@ -54,8 +54,20 @@ class Project(db.Model): def uses_weights(self): return len([i for i in self.members if i.weight != 1]) > 0 - def get_transactions_to_settle_bill(self): + def get_transactions_to_settle_bill(self, pretty_output=False): """Return a list of transactions that could be made to settle the bill""" + def prettify(transactions, pretty_output): + """ Return pretty transactions + """ + if not pretty_output: + return transactions + pretty_transactions = [] + for transaction in transactions: + pretty_transactions.append({'ower': transaction['ower'].name, + 'receiver': transaction['receiver'].name, + 'amount': round(transaction['amount'], 2)}) + return pretty_transactions + #cache value for better performance balance = self.balance credits, debts, transactions = [],[],[] @@ -83,7 +95,8 @@ class Project(db.Model): transactions.append({"ower": debts[0]["person"], "receiver": credits[0]["person"], "amount": credits[0]["balance"]}) debts[0]["balance"] = debts[0]["balance"] - credits[0]["balance"] del credits[0] - return transactions + + return prettify(transactions, pretty_output) def exactmatch(self, credit, debts): """Recursively try and find subsets of 'debts' whose sum is equal to credit""" @@ -114,6 +127,23 @@ class Project(db.Model): .order_by(Bill.date.desc())\ .order_by(Bill.id.desc()) + def get_pretty_bills(self, export_format="json"): + """Return a list of project's bills with pretty formatting""" + bills = self.get_bills() + pretty_bills = [] + for bill in bills: + if export_format == "json": + owers = [ower.name for ower in bill.owers] + else: + owers = ', '.join([ower.name for ower in bill.owers]) + pretty_bills.append({"what": bill.what, + "amount": round(bill.amount, 2), + "date": str(bill.date), + "payer_name": Person.query.get(bill.payer_id).name, + "payer_weight": Person.query.get(bill.payer_id).weight, + "owers": owers}) + return pretty_bills + def remove_member(self, member_id): """Remove a member from the project. diff --git a/budget/templates/edit_project.html b/budget/templates/edit_project.html index 0240eff..a5e85c3 100644 --- a/budget/templates/edit_project.html +++ b/budget/templates/edit_project.html @@ -10,6 +10,10 @@ {% block content %}

{{ _("Edit this project") }}

-{{ forms.edit_project(form) }} +{{ forms.edit_project(edit_form) }} +

+

{{ _("Download this project's data") }}

+
+{{ forms.export_project(export_form) }}
{% endblock %} diff --git a/budget/templates/forms.html b/budget/templates/forms.html index 3f8cc4a..b4fa236 100644 --- a/budget/templates/forms.html +++ b/budget/templates/forms.html @@ -150,6 +150,17 @@ {% endmacro %} +{% macro export_project(form) %} +
+ {{ form.hidden_tag() }} + {{ input(form.export_type) }} + {{ input(form.export_format) }} +
+
+ +
+{% endmacro %} + {% macro remind_password(form) %} {% include "display_errors.html" %} diff --git a/budget/tests.py b/budget/tests.py index c650c80..3a3d850 100644 --- a/budget/tests.py +++ b/budget/tests.py @@ -624,6 +624,116 @@ class BudgetTestCase(TestCase): self.assertNotEqual(0.0, rounded_amount, msg='%f is equal to zero after rounding' % t['amount']) + def test_export(self): + self.post_project("raclette") + + # add members + self.app.post("/raclette/members/add", data={'name': 'alexis', 'weight': 2}) + self.app.post("/raclette/members/add", data={'name': 'fred'}) + self.app.post("/raclette/members/add", data={'name': 'tata'}) + self.app.post("/raclette/members/add", data={'name': 'pépé'}) + + # create bills + self.app.post("/raclette/add", data={ + 'date': '2016-12-31', + 'what': u'fromage à raclette', + 'payer': 1, + 'payed_for': [1, 2, 3, 4], + 'amount': '10.0', + }) + + self.app.post("/raclette/add", data={ + 'date': '2016-12-31', + 'what': u'red wine', + 'payer': 2, + 'payed_for': [1, 3], + 'amount': '200', + }) + + self.app.post("/raclette/add", data={ + 'date': '2017-01-01', + 'what': u'refund', + 'payer': 3, + 'payed_for': [2], + 'amount': '13.33', + }) + + # generate json export of bills + resp = self.app.post("/raclette/edit", data={ + 'export_format': 'json', + 'export_type': 'bills' + }) + expected = [{u'date': u'2017-01-01', u'what': u'refund', + u'amount': 13.33, u'payer_name': u'tata', u'payer_weight': 1.0, u'owers': [u'fred']}, + {u'date': u'2016-12-31', u'what': u'red wine', + u'amount': 200.0, u'payer_name': u'fred', u'payer_weight': 1.0, u'owers': [u'alexis', u'tata']}, + {u'date': u'2016-12-31', u'what': u'fromage \xe0 raclette', + u'amount': 10.0, u'payer_name': u'alexis', u'payer_weight': 2.0, u'owers': [u'alexis', u'fred', u'tata', u'p\xe9p\xe9']}] + self.assertEqual(json.loads(resp.data), expected) + + # generate csv export of bills + resp = self.app.post("/raclette/edit", data={ + 'export_format': 'csv', + 'export_type': 'bills' + }) + expected = ["date,what,amount,payer_name,payer_weight,owers", + "2017-01-01,refund,13.33,tata,1.0,fred", + "2016-12-31,red wine,200.0,fred,1.0,\"alexis, tata\"", + "2016-12-31,fromage à raclette,10.0,alexis,2.0,\"alexis, fred, tata, pépé\""] + received_lines = resp.data.split("\n") + + for i, line in enumerate(expected): + self.assertEqual( + set(line.split(",")), + set(received_lines[i].strip("\r").split(",")) + ) + + # generate json export of transactions + resp = self.app.post("/raclette/edit", data={ + 'export_format': 'json', + 'export_type': 'transactions' + }) + expected = [{u"amount": 127.33, u"receiver": u"fred", u"ower": u"alexis"}, + {u"amount": 55.34, u"receiver": u"fred", u"ower": u"tata"}, + {u"amount": 2.00, u"receiver": u"fred", u"ower": u"p\xe9p\xe9"}] + self.assertEqual(json.loads(resp.data), expected) + + # generate csv export of transactions + resp = self.app.post("/raclette/edit", data={ + 'export_format': 'csv', + 'export_type': 'transactions' + }) + + expected = ["amount,receiver,ower", + "127.33,fred,alexis", + "55.34,fred,tata", + "2.0,fred,pépé"] + received_lines = resp.data.split("\n") + + for i, line in enumerate(expected): + self.assertEqual( + set(line.split(",")), + set(received_lines[i].strip("\r").split(",")) + ) + + # wrong export_format should return a 200 and export form + resp = self.app.post("/raclette/edit", data={ + 'export_format': 'wrong_export_format', + 'export_type': 'transactions' + }) + + self.assertEqual(resp.status_code, 200) + self.assertIn('id="export_format" name="export_format"', resp.data) + + # wrong export_type should return a 200 and export form + resp = self.app.post("/raclette/edit", data={ + 'export_format': 'json', + 'export_type': 'wrong_export_type' + }) + + self.assertEqual(resp.status_code, 200) + self.assertIn('id="export_format" name="export_format"', resp.data) + class APITestCase(TestCase): """Tests the API""" diff --git a/budget/translations/fr/LC_MESSAGES/messages.mo b/budget/translations/fr/LC_MESSAGES/messages.mo index 558d835..ef40aa5 100644 Binary files a/budget/translations/fr/LC_MESSAGES/messages.mo and b/budget/translations/fr/LC_MESSAGES/messages.mo differ diff --git a/budget/translations/fr/LC_MESSAGES/messages.po b/budget/translations/fr/LC_MESSAGES/messages.po index b7f8d63..44b63f9 100644 --- a/budget/translations/fr/LC_MESSAGES/messages.po +++ b/budget/translations/fr/LC_MESSAGES/messages.po @@ -146,6 +146,22 @@ msgstr "Date de départ" msgid "End date" msgstr "Date de fin" +#: forms.py:202 +msgid "What do you want to download ?" +msgstr "Que voulez-vous télécharger ?" + +#: forms.py:205 +msgid "bills" +msgstr "factures" + +#: forms.py:205 +msgid "transactions" +msgstr "remboursements" + +#: forms.py:206 +msgid "Export file format" +msgstr "Format du fichier d'export" + #: web.py:95 msgid "This private code is not the right one" msgstr "Le code que vous avez entré n'est pas correct" @@ -312,6 +328,14 @@ msgstr "Créer une archive" msgid "Create the archive" msgstr "Créer l'archive" +#: templates/forms.html:136 +msgid "Download this project's data" +msgstr "Télécharger les données de ce projet" + +#: templates/forms.html:136 +msgid "Download" +msgstr "Télécharger" + #: templates/home.html:8 msgid "Manage your shared
expenses, easily" msgstr "Gérez vos dépenses
partagées, facilement" diff --git a/budget/utils.py b/budget/utils.py index c849af0..8f5d3d5 100644 --- a/budget/utils.py +++ b/budget/utils.py @@ -2,8 +2,12 @@ import re import inspect from jinja2 import filters +from json import dumps from flask import redirect from werkzeug.routing import HTTPException, RoutingException +from io import BytesIO + +import csv def slugify(value): @@ -77,3 +81,30 @@ def minimal_round(*args, **kw): # return depending on it ires = int(res) return (res if res != ires else ires) + +def list_of_dicts2json(dict_to_convert): + """Take a list of dictionnaries and turns it into + a json in-memory file + """ + bytes_io = BytesIO() + bytes_io.write(dumps(dict_to_convert)) + bytes_io.seek(0) + return bytes_io + +def list_of_dicts2csv(dict_to_convert): + """Take a list of dictionnaries and turns it into + a csv in-memory file, assume all dict have the same keys + """ + bytes_io = BytesIO() + try: + csv_data = [dict_to_convert[0].keys()] + for dic in dict_to_convert: + csv_data.append([dic[h].encode('utf8') + if isinstance(dic[h], unicode) else str(dic[h]).encode('utf8') + for h in dict_to_convert[0].keys()]) + except (KeyError, IndexError): + csv_data = [] + writer = csv.writer(bytes_io) + writer.writerows(csv_data) + bytes_io.seek(0) + return bytes_io diff --git a/budget/web.py b/budget/web.py index 28ed344..1c58a62 100644 --- a/budget/web.py +++ b/budget/web.py @@ -10,7 +10,7 @@ and `add_project_id` for a quick overview) """ from flask import Blueprint, current_app, flash, g, redirect, \ - render_template, request, session, url_for + render_template, request, session, url_for, send_file from flask_mail import Mail, Message from flask_babel import get_locale, gettext as _ from smtplib import SMTPRecipientsRefused @@ -20,9 +20,9 @@ from sqlalchemy import orm # local modules from models import db, Project, Person, Bill from forms import AuthenticationForm, CreateArchiveForm, EditProjectForm, \ - InviteForm, MemberForm, PasswordReminder, ProjectForm, get_billform_for -from utils import Redirect303 - + InviteForm, MemberForm, PasswordReminder, ProjectForm, get_billform_for, \ + ExportForm +from utils import Redirect303, list_of_dicts2json, list_of_dicts2csv main = Blueprint("main", __name__) mail = Mail() @@ -197,20 +197,43 @@ def remind_password(): @main.route("//edit", methods=["GET", "POST"]) def edit_project(): - form = EditProjectForm() + edit_form = EditProjectForm() + export_form = ExportForm() if request.method == "POST": - if form.validate(): - project = form.update(g.project) + if edit_form.validate(): + project = edit_form.update(g.project) db.session.commit() session[project.id] = project.password return redirect(url_for(".list_bills")) - else: - form.name.data = g.project.name - form.password.data = g.project.password - form.contact_email.data = g.project.contact_email - return render_template("edit_project.html", form=form) + if export_form.validate(): + export_format = export_form.export_format.data + export_type = export_form.export_type.data + + if export_type == 'transactions': + export = g.project.get_transactions_to_settle_bill( + pretty_output=True) + if export_type == "bills": + export = g.project.get_pretty_bills( + export_format=export_format) + + if export_format == "json": + file2export = list_of_dicts2json(export) + if export_format == "csv": + file2export = list_of_dicts2csv(export) + + return send_file(file2export, + attachment_filename="%s-%s.%s" % + (g.project.name, export_type, export_format), + as_attachment=True + ) + else: + edit_form.name.data = g.project.name + edit_form.password.data = g.project.password + edit_form.contact_email.data = g.project.contact_email + + return render_template("edit_project.html", edit_form=edit_form, export_form=export_form) @main.route("//delete")