diff --git a/budget/forms.py b/budget/forms.py
index c2455da..a097abe 100644
--- a/budget/forms.py
+++ b/budget/forms.py
@@ -174,3 +174,16 @@ class CreateArchiveForm(Form):
name = TextField(_("Name for this archive (optional)"), validators=[])
start_date = DateField(_("Start date"), validators=[Required()])
end_date = DateField(_("End date"), validators=[Required()], default=datetime.now)
+
+
+class ExportForm(Form):
+ export_type = SelectField(_("What do you want to download ?"),
+ validators=[Required()],
+ coerce=str,
+ choices=[("bills", _("bills")), ("transactions", _("transactions"))]
+ )
+ export_format = SelectField(_("Export file format"),
+ validators=[Required()],
+ coerce=str,
+ choices=[("csv", "csv"), ("json", "json")]
+ )
diff --git a/budget/models.py b/budget/models.py
index f894207..8f570fb 100644
--- a/budget/models.py
+++ b/budget/models.py
@@ -54,8 +54,20 @@ class Project(db.Model):
def uses_weights(self):
return len([i for i in self.members if i.weight != 1]) > 0
- def get_transactions_to_settle_bill(self):
+ def get_transactions_to_settle_bill(self, pretty_output=False):
"""Return a list of transactions that could be made to settle the bill"""
+ def prettify(transactions, pretty_output):
+ """ Return pretty transactions
+ """
+ if not pretty_output:
+ return transactions
+ pretty_transactions = []
+ for transaction in transactions:
+ pretty_transactions.append({'ower': transaction['ower'].name,
+ 'receiver': transaction['receiver'].name,
+ 'amount': round(transaction['amount'], 2)})
+ return pretty_transactions
+
#cache value for better performance
balance = self.balance
credits, debts, transactions = [],[],[]
@@ -83,7 +95,8 @@ class Project(db.Model):
transactions.append({"ower": debts[0]["person"], "receiver": credits[0]["person"], "amount": credits[0]["balance"]})
debts[0]["balance"] = debts[0]["balance"] - credits[0]["balance"]
del credits[0]
- return transactions
+
+ return prettify(transactions, pretty_output)
def exactmatch(self, credit, debts):
"""Recursively try and find subsets of 'debts' whose sum is equal to credit"""
@@ -114,6 +127,23 @@ class Project(db.Model):
.order_by(Bill.date.desc())\
.order_by(Bill.id.desc())
+ def get_pretty_bills(self, export_format="json"):
+ """Return a list of project's bills with pretty formatting"""
+ bills = self.get_bills()
+ pretty_bills = []
+ for bill in bills:
+ if export_format == "json":
+ owers = [ower.name for ower in bill.owers]
+ else:
+ owers = ', '.join([ower.name for ower in bill.owers])
+ pretty_bills.append({"what": bill.what,
+ "amount": round(bill.amount, 2),
+ "date": str(bill.date),
+ "payer_name": Person.query.get(bill.payer_id).name,
+ "payer_weight": Person.query.get(bill.payer_id).weight,
+ "owers": owers})
+ return pretty_bills
+
def remove_member(self, member_id):
"""Remove a member from the project.
diff --git a/budget/templates/edit_project.html b/budget/templates/edit_project.html
index 0240eff..a5e85c3 100644
--- a/budget/templates/edit_project.html
+++ b/budget/templates/edit_project.html
@@ -10,6 +10,10 @@
{% block content %}
{{ _("Edit this project") }}
+{{ _("Download this project's data") }}
+
{% endblock %}
diff --git a/budget/templates/forms.html b/budget/templates/forms.html
index 3f8cc4a..b4fa236 100644
--- a/budget/templates/forms.html
+++ b/budget/templates/forms.html
@@ -150,6 +150,17 @@
{% endmacro %}
+{% macro export_project(form) %}
+
+
+
+
+{% endmacro %}
+
{% macro remind_password(form) %}
{% include "display_errors.html" %}
diff --git a/budget/tests.py b/budget/tests.py
index c650c80..3a3d850 100644
--- a/budget/tests.py
+++ b/budget/tests.py
@@ -624,6 +624,116 @@ class BudgetTestCase(TestCase):
self.assertNotEqual(0.0, rounded_amount,
msg='%f is equal to zero after rounding' % t['amount'])
+ def test_export(self):
+ self.post_project("raclette")
+
+ # add members
+ self.app.post("/raclette/members/add", data={'name': 'alexis', 'weight': 2})
+ self.app.post("/raclette/members/add", data={'name': 'fred'})
+ self.app.post("/raclette/members/add", data={'name': 'tata'})
+ self.app.post("/raclette/members/add", data={'name': 'pépé'})
+
+ # create bills
+ self.app.post("/raclette/add", data={
+ 'date': '2016-12-31',
+ 'what': u'fromage à raclette',
+ 'payer': 1,
+ 'payed_for': [1, 2, 3, 4],
+ 'amount': '10.0',
+ })
+
+ self.app.post("/raclette/add", data={
+ 'date': '2016-12-31',
+ 'what': u'red wine',
+ 'payer': 2,
+ 'payed_for': [1, 3],
+ 'amount': '200',
+ })
+
+ self.app.post("/raclette/add", data={
+ 'date': '2017-01-01',
+ 'what': u'refund',
+ 'payer': 3,
+ 'payed_for': [2],
+ 'amount': '13.33',
+ })
+
+ # generate json export of bills
+ resp = self.app.post("/raclette/edit", data={
+ 'export_format': 'json',
+ 'export_type': 'bills'
+ })
+ expected = [{u'date': u'2017-01-01', u'what': u'refund',
+ u'amount': 13.33, u'payer_name': u'tata', u'payer_weight': 1.0, u'owers': [u'fred']},
+ {u'date': u'2016-12-31', u'what': u'red wine',
+ u'amount': 200.0, u'payer_name': u'fred', u'payer_weight': 1.0, u'owers': [u'alexis', u'tata']},
+ {u'date': u'2016-12-31', u'what': u'fromage \xe0 raclette',
+ u'amount': 10.0, u'payer_name': u'alexis', u'payer_weight': 2.0, u'owers': [u'alexis', u'fred', u'tata', u'p\xe9p\xe9']}]
+ self.assertEqual(json.loads(resp.data), expected)
+
+ # generate csv export of bills
+ resp = self.app.post("/raclette/edit", data={
+ 'export_format': 'csv',
+ 'export_type': 'bills'
+ })
+ expected = ["date,what,amount,payer_name,payer_weight,owers",
+ "2017-01-01,refund,13.33,tata,1.0,fred",
+ "2016-12-31,red wine,200.0,fred,1.0,\"alexis, tata\"",
+ "2016-12-31,fromage à raclette,10.0,alexis,2.0,\"alexis, fred, tata, pépé\""]
+ received_lines = resp.data.split("\n")
+
+ for i, line in enumerate(expected):
+ self.assertEqual(
+ set(line.split(",")),
+ set(received_lines[i].strip("\r").split(","))
+ )
+
+ # generate json export of transactions
+ resp = self.app.post("/raclette/edit", data={
+ 'export_format': 'json',
+ 'export_type': 'transactions'
+ })
+ expected = [{u"amount": 127.33, u"receiver": u"fred", u"ower": u"alexis"},
+ {u"amount": 55.34, u"receiver": u"fred", u"ower": u"tata"},
+ {u"amount": 2.00, u"receiver": u"fred", u"ower": u"p\xe9p\xe9"}]
+ self.assertEqual(json.loads(resp.data), expected)
+
+ # generate csv export of transactions
+ resp = self.app.post("/raclette/edit", data={
+ 'export_format': 'csv',
+ 'export_type': 'transactions'
+ })
+
+ expected = ["amount,receiver,ower",
+ "127.33,fred,alexis",
+ "55.34,fred,tata",
+ "2.0,fred,pépé"]
+ received_lines = resp.data.split("\n")
+
+ for i, line in enumerate(expected):
+ self.assertEqual(
+ set(line.split(",")),
+ set(received_lines[i].strip("\r").split(","))
+ )
+
+ # wrong export_format should return a 200 and export form
+ resp = self.app.post("/raclette/edit", data={
+ 'export_format': 'wrong_export_format',
+ 'export_type': 'transactions'
+ })
+
+ self.assertEqual(resp.status_code, 200)
+ self.assertIn('id="export_format" name="export_format"', resp.data)
+
+ # wrong export_type should return a 200 and export form
+ resp = self.app.post("/raclette/edit", data={
+ 'export_format': 'json',
+ 'export_type': 'wrong_export_type'
+ })
+
+ self.assertEqual(resp.status_code, 200)
+ self.assertIn('id="export_format" name="export_format"', resp.data)
+
class APITestCase(TestCase):
"""Tests the API"""
diff --git a/budget/translations/fr/LC_MESSAGES/messages.mo b/budget/translations/fr/LC_MESSAGES/messages.mo
index 558d835..ef40aa5 100644
Binary files a/budget/translations/fr/LC_MESSAGES/messages.mo and b/budget/translations/fr/LC_MESSAGES/messages.mo differ
diff --git a/budget/translations/fr/LC_MESSAGES/messages.po b/budget/translations/fr/LC_MESSAGES/messages.po
index b7f8d63..44b63f9 100644
--- a/budget/translations/fr/LC_MESSAGES/messages.po
+++ b/budget/translations/fr/LC_MESSAGES/messages.po
@@ -146,6 +146,22 @@ msgstr "Date de départ"
msgid "End date"
msgstr "Date de fin"
+#: forms.py:202
+msgid "What do you want to download ?"
+msgstr "Que voulez-vous télécharger ?"
+
+#: forms.py:205
+msgid "bills"
+msgstr "factures"
+
+#: forms.py:205
+msgid "transactions"
+msgstr "remboursements"
+
+#: forms.py:206
+msgid "Export file format"
+msgstr "Format du fichier d'export"
+
#: web.py:95
msgid "This private code is not the right one"
msgstr "Le code que vous avez entré n'est pas correct"
@@ -312,6 +328,14 @@ msgstr "Créer une archive"
msgid "Create the archive"
msgstr "Créer l'archive"
+#: templates/forms.html:136
+msgid "Download this project's data"
+msgstr "Télécharger les données de ce projet"
+
+#: templates/forms.html:136
+msgid "Download"
+msgstr "Télécharger"
+
#: templates/home.html:8
msgid "Manage your shared
expenses, easily"
msgstr "Gérez vos dépenses
partagées, facilement"
diff --git a/budget/utils.py b/budget/utils.py
index c849af0..8f5d3d5 100644
--- a/budget/utils.py
+++ b/budget/utils.py
@@ -2,8 +2,12 @@ import re
import inspect
from jinja2 import filters
+from json import dumps
from flask import redirect
from werkzeug.routing import HTTPException, RoutingException
+from io import BytesIO
+
+import csv
def slugify(value):
@@ -77,3 +81,30 @@ def minimal_round(*args, **kw):
# return depending on it
ires = int(res)
return (res if res != ires else ires)
+
+def list_of_dicts2json(dict_to_convert):
+ """Take a list of dictionnaries and turns it into
+ a json in-memory file
+ """
+ bytes_io = BytesIO()
+ bytes_io.write(dumps(dict_to_convert))
+ bytes_io.seek(0)
+ return bytes_io
+
+def list_of_dicts2csv(dict_to_convert):
+ """Take a list of dictionnaries and turns it into
+ a csv in-memory file, assume all dict have the same keys
+ """
+ bytes_io = BytesIO()
+ try:
+ csv_data = [dict_to_convert[0].keys()]
+ for dic in dict_to_convert:
+ csv_data.append([dic[h].encode('utf8')
+ if isinstance(dic[h], unicode) else str(dic[h]).encode('utf8')
+ for h in dict_to_convert[0].keys()])
+ except (KeyError, IndexError):
+ csv_data = []
+ writer = csv.writer(bytes_io)
+ writer.writerows(csv_data)
+ bytes_io.seek(0)
+ return bytes_io
diff --git a/budget/web.py b/budget/web.py
index 28ed344..1c58a62 100644
--- a/budget/web.py
+++ b/budget/web.py
@@ -10,7 +10,7 @@ and `add_project_id` for a quick overview)
"""
from flask import Blueprint, current_app, flash, g, redirect, \
- render_template, request, session, url_for
+ render_template, request, session, url_for, send_file
from flask_mail import Mail, Message
from flask_babel import get_locale, gettext as _
from smtplib import SMTPRecipientsRefused
@@ -20,9 +20,9 @@ from sqlalchemy import orm
# local modules
from models import db, Project, Person, Bill
from forms import AuthenticationForm, CreateArchiveForm, EditProjectForm, \
- InviteForm, MemberForm, PasswordReminder, ProjectForm, get_billform_for
-from utils import Redirect303
-
+ InviteForm, MemberForm, PasswordReminder, ProjectForm, get_billform_for, \
+ ExportForm
+from utils import Redirect303, list_of_dicts2json, list_of_dicts2csv
main = Blueprint("main", __name__)
mail = Mail()
@@ -197,20 +197,43 @@ def remind_password():
@main.route("//edit", methods=["GET", "POST"])
def edit_project():
- form = EditProjectForm()
+ edit_form = EditProjectForm()
+ export_form = ExportForm()
if request.method == "POST":
- if form.validate():
- project = form.update(g.project)
+ if edit_form.validate():
+ project = edit_form.update(g.project)
db.session.commit()
session[project.id] = project.password
return redirect(url_for(".list_bills"))
- else:
- form.name.data = g.project.name
- form.password.data = g.project.password
- form.contact_email.data = g.project.contact_email
- return render_template("edit_project.html", form=form)
+ if export_form.validate():
+ export_format = export_form.export_format.data
+ export_type = export_form.export_type.data
+
+ if export_type == 'transactions':
+ export = g.project.get_transactions_to_settle_bill(
+ pretty_output=True)
+ if export_type == "bills":
+ export = g.project.get_pretty_bills(
+ export_format=export_format)
+
+ if export_format == "json":
+ file2export = list_of_dicts2json(export)
+ if export_format == "csv":
+ file2export = list_of_dicts2csv(export)
+
+ return send_file(file2export,
+ attachment_filename="%s-%s.%s" %
+ (g.project.name, export_type, export_format),
+ as_attachment=True
+ )
+ else:
+ edit_form.name.data = g.project.name
+ edit_form.password.data = g.project.password
+ edit_form.contact_email.data = g.project.contact_email
+
+ return render_template("edit_project.html", edit_form=edit_form, export_form=export_form)
@main.route("//delete")