1
0
Fork 0
mirror of https://github.com/YunoHost-Apps/ihatemoney_ynh.git synced 2024-09-03 19:26:15 +02:00

Complete the REST API + Tests. Fix #27

This commit is contained in:
Alexis Metaireau 2011-10-08 13:22:18 +02:00
parent 402dbce153
commit 48bc551853
7 changed files with 353 additions and 52 deletions

View file

@ -2,8 +2,8 @@
from flask import * from flask import *
from models import db, Project, Person, Bill from models import db, Project, Person, Bill
from forms import ProjectForm from forms import ProjectForm, EditProjectForm, MemberForm, BillForm
from utils import for_all_methods from utils import for_all_methods, get_billform_for
from rest import RESTResource, need_auth# FIXME make it an ext from rest import RESTResource, need_auth# FIXME make it an ext
from werkzeug import Response from werkzeug import Response
@ -32,7 +32,7 @@ class ProjectHandler(object):
def add(self): def add(self):
form = ProjectForm(csrf_enabled=False) form = ProjectForm(csrf_enabled=False)
if form.validate(): if form.validate():
project = form.save(Project()) project = form.save()
db.session.add(project) db.session.add(project)
db.session.commit() db.session.commit()
return 201, project.id return 201, project.id
@ -40,7 +40,7 @@ class ProjectHandler(object):
@need_auth(check_project, "project") @need_auth(check_project, "project")
def get(self, project): def get(self, project):
return project return 200, project
@need_auth(check_project, "project") @need_auth(check_project, "project")
def delete(self, project): def delete(self, project):
@ -50,9 +50,9 @@ class ProjectHandler(object):
@need_auth(check_project, "project") @need_auth(check_project, "project")
def update(self, project): def update(self, project):
form = ProjectForm(csrf_enabled=False) form = EditProjectForm(csrf_enabled=False)
if form.validate(): if form.validate():
form.save(project) form.update(project)
db.session.commit() db.session.commit()
return 200, "UPDATED" return 200, "UPDATED"
return 400, form.errors return 400, form.errors
@ -61,25 +61,25 @@ class ProjectHandler(object):
class MemberHandler(object): class MemberHandler(object):
def get(self, project, member_id): def get(self, project, member_id):
member = Person.query.get(member_id) member = Person.query.get(member_id, project)
if not member or member.project != project: if not member or member.project != project:
return 404, "Not Found" return 404, "Not Found"
return member return 200, member
def list(self, project): def list(self, project):
return project.members return 200, project.members
def add(self, project): def add(self, project):
form = MemberForm(csrf_enabled=False) form = MemberForm(project, csrf_enabled=False)
if form.validate(): if form.validate():
member = Person() member = Person()
form.save(project, member) form.save(project, member)
db.session.commit() db.session.commit()
return 200, member.id return 201, member.id
return 400, form.errors return 400, form.errors
def update(self, project, member_id): def update(self, project, member_id):
form = MemberForm(csrf_enabled=False) form = MemberForm(project, csrf_enabled=False)
if form.validate(): if form.validate():
member = Person.query.get(member_id, project) member = Person.query.get(member_id, project)
form.save(project, member) form.save(project, member)
@ -99,39 +99,41 @@ class BillHandler(object):
bill = Bill.query.get(project, bill_id) bill = Bill.query.get(project, bill_id)
if not bill: if not bill:
return 404, "Not Found" return 404, "Not Found"
return bill return 200, bill
def list(self, project): def list(self, project):
return project.get_bills().all() return project.get_bills().all()
def add(self, project): def add(self, project):
form = BillForm(csrf_enabled=False) form = get_billform_for(project, True, csrf_enabled=False)
if form.validate(): if form.validate():
bill = Bill() bill = Bill()
form.save(bill) form.save(bill, project)
db.session.add(bill) db.session.add(bill)
db.session.commit() db.session.commit()
return 200, bill.id return 201, bill.id
return 400, form.errors return 400, form.errors
def update(self, project, bill_id): def update(self, project, bill_id):
form = BillForm(csrf_enabled=False) form = get_billform_for(project, True, csrf_enabled=False)
if form.validate(): if form.validate():
form.save(bill) bill = Bill.query.get(project, bill_id)
form.save(bill, project)
db.session.commit() db.session.commit()
return 200, bill.id return 200, bill.id
return 400, form.errors return 400, form.errors
def delete(self, project, bill_id): def delete(self, project, bill_id):
bill = Bill.query.delete(project, bill_id) bill = Bill.query.delete(project, bill_id)
db.session.commit()
if not bill: if not bill:
return 404, "Not Found" return 404, "Not Found"
return bill return 200, "OK"
project_resource = RESTResource( project_resource = RESTResource(
name="project", name="project",
route="/project", route="/projects",
app=api, app=api,
actions=["add", "update", "delete", "get"], actions=["add", "update", "delete", "get"],
handler=ProjectHandler()) handler=ProjectHandler())
@ -139,7 +141,7 @@ project_resource = RESTResource(
member_resource = RESTResource( member_resource = RESTResource(
name="member", name="member",
inject_name="project", inject_name="project",
route="/project/<project_id>/members", route="/projects/<project_id>/members",
app=api, app=api,
handler=MemberHandler(), handler=MemberHandler(),
authentifier=check_project) authentifier=check_project)
@ -147,7 +149,7 @@ member_resource = RESTResource(
bill_resource = RESTResource( bill_resource = RESTResource(
name="bill", name="bill",
inject_name="project", inject_name="project",
route="/project/<project_id>/bills", route="/projects/<project_id>/bills",
app=api, app=api,
handler=BillHandler(), handler=BillHandler(),
authentifier=check_project) authentifier=check_project)

View file

@ -95,12 +95,13 @@ class BillForm(Form):
validators=[Required()], widget=select_multi_checkbox) validators=[Required()], widget=select_multi_checkbox)
submit = SubmitField("Send the bill") submit = SubmitField("Send the bill")
def save(self, bill): def save(self, bill, project):
bill.payer_id=self.payer.data bill.payer_id=self.payer.data
bill.amount=self.amount.data bill.amount=self.amount.data
bill.what=self.what.data bill.what=self.what.data
bill.date=self.date.data bill.date=self.date.data
bill.owers = [Person.query.get(ower) for ower in self.payed_for.data] bill.owers = [Person.query.get(ower, project)
for ower in self.payed_for.data]
return bill return bill

View file

@ -60,8 +60,7 @@ class Project(db.Model):
This method returns the status DELETED or DEACTIVATED regarding the This method returns the status DELETED or DEACTIVATED regarding the
changes made. changes made.
""" """
person = Person.query.get_or_404(member_id) person = Person.query.get(member_id, self)
if person.project == self:
if not person.has_bills(): if not person.has_bills():
db.session.delete(person) db.session.delete(person)
db.session.commit() db.session.commit()

View file

@ -90,7 +90,7 @@ def need_auth(authentifier, name=None, remove_attr=True):
If the request is authorized, the object returned by the authentifier If the request is authorized, the object returned by the authentifier
is added to the kwargs of the method. is added to the kwargs of the method.
If not, issue a 403 Forbidden error If not, issue a 401 Unauthorized error
:authentifier: :authentifier:
The callable to check the context onto. The callable to check the context onto.
@ -114,7 +114,7 @@ def need_auth(authentifier, name=None, remove_attr=True):
del kwargs["%s_id" % name] del kwargs["%s_id" % name]
return func(*args, **kwargs) return func(*args, **kwargs)
else: else:
return 403, "Forbidden" return 401, "Unauthorized"
return wrapped return wrapped
return wrapper return wrapper
@ -126,7 +126,7 @@ def serialize(func):
""" """
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
# get the mimetype # get the mimetype
mime = request.accept_mimetypes.best_match(SERIALIZERS.keys()) mime = request.accept_mimetypes.best_match(SERIALIZERS.keys()) or "text/json"
data = func(*args, **kwargs) data = func(*args, **kwargs)
serializer = SERIALIZERS[mime] serializer = SERIALIZERS[mime]

View file

@ -2,6 +2,8 @@
import os import os
import tempfile import tempfile
import unittest import unittest
import base64
import json
from flask import session from flask import session
@ -333,5 +335,287 @@ class BudgetTestCase(TestCase):
self.assertIn("Invalid email address", resp.data) self.assertIn("Invalid email address", resp.data)
class APITestCase(TestCase):
"""Tests the API"""
def api_create(self, name, id=None, password=None, contact=None):
id = id or name
password = password or name
contact = contact or "%s@notmyidea.org" % name
return self.app.post("/api/projects", data={
'name': name,
'id': id,
'password': password,
'contact_email': contact
})
def api_add_member(self, project, name):
self.app.post("/api/projects/%s/members" % project,
data={"name": name}, headers=self.get_auth(project))
def get_auth(self, username, password=None):
password = password or username
base64string = base64.encodestring(
'%s:%s' % (username, password)).replace('\n', '')
return {"Authorization": "Basic %s" % base64string}
def assertStatus(self, expected, resp, url=""):
return self.assertEqual(expected, resp.status_code,
"%s expected %s, got %s" % (url, expected, resp.status_code))
def test_basic_auth(self):
# create a project
resp = self.api_create("raclette")
self.assertStatus(201, resp)
# try to do something on it being unauth should return a 401
resp = self.app.get("/api/projects/raclette")
self.assertStatus(401, resp)
# PUT / POST / DELETE / GET on the different resources
# should also return a 401
for verb in ('post',):
for resource in ("/raclette/members", "/raclette/bills"):
url = "/api/projects" + resource
self.assertStatus(401, getattr(self.app, verb)(url),
verb + resource)
for verb in ('get', 'delete', 'put'):
for resource in ("/raclette", "/raclette/members/1",
"/raclette/bills/1"):
url = "/api/projects" + resource
self.assertStatus(401, getattr(self.app, verb)(url),
verb + resource)
def test_project(self):
# wrong email should return an error
resp = self.app.post("/api/projects", data={
'name': "raclette",
'id': "raclette",
'password': "raclette",
'contact_email': "not-an-email"
})
self.assertTrue(400, resp.status_code)
self.assertEqual('{"contact_email": ["Invalid email address."]}', resp.data)
# create it
resp = self.api_create("raclette")
self.assertTrue(201, resp.status_code)
# create it twice should return a 400
resp = self.api_create("raclette")
self.assertTrue(400, resp.status_code)
self.assertEqual('{"id": ["This project id is already used"]}', resp.data)
# get information about it
resp = self.app.get("/api/projects/raclette",
headers=self.get_auth("raclette"))
self.assertTrue(200, resp.status_code)
expected = {
"active_members": [],
"name": "raclette",
"contact_email": "raclette@notmyidea.org",
"members": [],
"password": "raclette",
"id": "raclette"
}
self.assertDictEqual(json.loads(resp.data), expected)
# edit should work
resp = self.app.put("/api/projects/raclette", data = {
"contact_email": "yeah@notmyidea.org",
"password": "raclette",
"name": "The raclette party",
}, headers=self.get_auth("raclette"))
self.assertEqual(200, resp.status_code)
resp = self.app.get("/api/projects/raclette",
headers=self.get_auth("raclette"))
self.assertEqual(200, resp.status_code)
expected = {
"active_members": [],
"name": "The raclette party",
"contact_email": "yeah@notmyidea.org",
"members": [],
"password": "raclette",
"id": "raclette"
}
self.assertDictEqual(json.loads(resp.data), expected)
# delete should work
resp = self.app.delete("/api/projects/raclette",
headers=self.get_auth("raclette"))
self.assertEqual(200, resp.status_code)
# get should return a 401 on an unknown resource
resp = self.app.get("/api/projects/raclette",
headers=self.get_auth("raclette"))
self.assertEqual(401, resp.status_code)
def test_member(self):
# create a project
self.api_create("raclette")
# get the list of members (should be empty)
req = self.app.get("/api/projects/raclette/members",
headers=self.get_auth("raclette"))
self.assertStatus(200, req)
self.assertEqual('[]', req.data)
# add a member
req = self.app.post("/api/projects/raclette/members", data={
"name": "Alexis"
}, headers=self.get_auth("raclette"))
# the id of the new member should be returned
self.assertStatus(201, req)
self.assertEqual("1", req.data)
# the list of members should contain one member
req = self.app.get("/api/projects/raclette/members",
headers=self.get_auth("raclette"))
self.assertStatus(200, req)
self.assertEqual(len(json.loads(req.data)), 1)
# edit this member
req = self.app.put("/api/projects/raclette/members/1", data={
"name": "Fred"
}, headers=self.get_auth("raclette"))
self.assertStatus(200, req)
# get should return the new name
req = self.app.get("/api/projects/raclette/members/1",
headers=self.get_auth("raclette"))
self.assertStatus(200, req)
self.assertEqual("Fred", json.loads(req.data)["name"])
# delete a member
req = self.app.delete("/api/projects/raclette/members/1",
headers=self.get_auth("raclette"))
self.assertStatus(200, req)
# the list of members should be empty
# get the list of members (should be empty)
req = self.app.get("/api/projects/raclette/members",
headers=self.get_auth("raclette"))
self.assertStatus(200, req)
self.assertEqual('[]', req.data)
def test_bills(self):
# create a project
self.api_create("raclette")
# add members
self.api_add_member("raclette", "alexis")
self.api_add_member("raclette", "fred")
self.api_add_member("raclette", "arnaud")
# get the list of bills (should be empty)
req = self.app.get("/api/projects/raclette/bills",
headers=self.get_auth("raclette"))
self.assertStatus(200, req)
self.assertEqual("[]", req.data)
# add a bill
req = self.app.post("/api/projects/raclette/bills", data={
'date': '2011-08-10',
'what': u'fromage',
'payer': "1",
'payed_for': ["1", "2"],
'amount': '25',
}, headers=self.get_auth("raclette"))
# should return the id
self.assertStatus(201, req)
self.assertEqual(req.data, "1")
# get this bill details
req = self.app.get("/api/projects/raclette/bills/1",
headers=self.get_auth("raclette"))
# compare with the added info
self.assertStatus(200, req)
expected = {
"what": "fromage",
"payer_id": 1,
"owers": [
{"activated": True, "id": 1, "name": "alexis"},
{"activated": True, "id": 2, "name": "fred"}],
"amount": 25.0,
"date": "2011-08-10",
"id": 1}
self.assertDictEqual(expected, json.loads(req.data))
# the list of bills should lenght 1
req = self.app.get("/api/projects/raclette/bills",
headers=self.get_auth("raclette"))
self.assertStatus(200, req)
self.assertEqual(1, len(json.loads(req.data)))
# edit with errors should return an error
req = self.app.put("/api/projects/raclette/bills/1", data={
'date': '201111111-08-10', # not a date
'what': u'fromage',
'payer': "1",
'payed_for': ["1", "2"],
'amount': '25',
}, headers=self.get_auth("raclette"))
self.assertStatus(400, req)
self.assertEqual('{"date": ["This field is required."]}', req.data)
# edit a bill
req = self.app.put("/api/projects/raclette/bills/1", data={
'date': '2011-09-10',
'what': u'beer',
'payer': "2",
'payed_for': ["1", "2"],
'amount': '25',
}, headers=self.get_auth("raclette"))
# check its fields
req = self.app.get("/api/projects/raclette/bills/1",
headers=self.get_auth("raclette"))
expected = {
"what": "beer",
"payer_id": 2,
"owers": [
{"activated": True, "id": 1, "name": "alexis"},
{"activated": True, "id": 2, "name": "fred"}],
"amount": 25.0,
"date": "2011-09-10",
"id": 1}
self.assertDictEqual(expected, json.loads(req.data))
# delete a bill
req = self.app.delete("/api/projects/raclette/bills/1",
headers=self.get_auth("raclette"))
self.assertStatus(200, req)
# getting it should return a 404
req = self.app.get("/api/projects/raclette/bills/1",
headers=self.get_auth("raclette"))
self.assertStatus(404, req)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View file

@ -17,6 +17,22 @@ def slugify(value):
value = unicode(re.sub('[^\w\s-]', '', value).strip().lower()) value = unicode(re.sub('[^\w\s-]', '', value).strip().lower())
return re.sub('[-\s]+', '-', value) return re.sub('[-\s]+', '-', value)
def get_billform_for(project, set_default=True, **kwargs):
"""Return an instance of BillForm configured for a particular project.
:set_default: if set to True, on GET methods (usually when we want to
display the default form, it will call set_default on it.
"""
form = BillForm(**kwargs)
form.payed_for.choices = form.payer.choices = [(str(m.id), m.name) for m in project.active_members]
form.payed_for.default = [str(m.id) for m in project.active_members]
if set_default and request.method == "GET":
form.set_default()
return form
class Redirect303(HTTPException, RoutingException): class Redirect303(HTTPException, RoutingException):
"""Raise if the map requests a redirect. This is for example the case if """Raise if the map requests a redirect. This is for example the case if
`strict_slashes` are activated and an url that requires a trailing slash. `strict_slashes` are activated and an url that requires a trailing slash.
@ -39,4 +55,3 @@ def for_all_methods(decorator):
setattr(cls, name, decorator(method)) setattr(cls, name, decorator(method))
return cls return cls
return decorate return decorate

View file

@ -262,7 +262,7 @@ def add_bill():
if request.method == 'POST': if request.method == 'POST':
if form.validate(): if form.validate():
bill = Bill() bill = Bill()
db.session.add(form.save(bill)) db.session.add(form.save(bill, g.project))
db.session.commit() db.session.commit()
flash("The bill has been added") flash("The bill has been added")
@ -295,7 +295,7 @@ def edit_bill(bill_id):
form = get_billform_for(request, g.project, set_default=False) form = get_billform_for(request, g.project, set_default=False)
if request.method == 'POST' and form.validate(): if request.method == 'POST' and form.validate():
form.save(bill) form.save(bill, g.project)
db.session.commit() db.session.commit()
flash("The bill has been modified") flash("The bill has been modified")