diff --git a/server.py b/server.py index 0ff6f73..b3ffa95 100644 --- a/server.py +++ b/server.py @@ -1,3 +1,4 @@ +import time import asyncio import aiodns import aiohttp @@ -10,6 +11,26 @@ from sanic.exceptions import InvalidUsage app = Sanic() +# keep that in memory +RATE_LIMIT_DB = {} + +# to prevent DDoS or bounce attack attempt or something like that +RATE_LIMIT_SECONDS = 5 + + +def clear_rate_limit_db(now): + to_delete = [] + + "Remove too old rate limit values" + for key, value in RATE_LIMIT_DB.items(): + if now - value > RATE_LIMIT_SECONDS: + # a dictionnary can't be modified during iteration so delegate this + # operation + to_delete.append(key) + + for key in to_delete: + del RATE_LIMIT_DB[key] + async def query_dns(host, dns_entry_type): loop = asyncio.get_event_loop() @@ -27,8 +48,24 @@ async def query_dns(host, dns_entry_type): @app.route("/check/", methods=["POST"]) async def check_http(request): + # this is supposed to be a fast operation if run enough + now = time.time() + clear_rate_limit_db(now) + ip = request.ip + if ip in RATE_LIMIT_DB: + since_last_attempt = now - RATE_LIMIT_DB[ip] + if since_last_attempt < RATE_LIMIT_SECONDS: + logger.info(f"Rate limite {ip}, can retry in {int(RATE_LIMIT_SECONDS - since_last_attempt)} seconds") + return json_response({ + "status": "error", + "code": "error_rate_limit", + "content": f"Rate limit on ip, retry in {int(RATE_LIMIT_SECONDS - since_last_attempt)} seconds", + }) + + RATE_LIMIT_DB[ip] = time.time() + try: data = request.json except InvalidUsage: @@ -49,6 +86,18 @@ async def check_http(request): domain = data["domain"] + if domain in RATE_LIMIT_DB: + since_last_attempt = now - RATE_LIMIT_DB[domain] + if since_last_attempt < RATE_LIMIT_SECONDS: + logger.info(f"Rate limite {domain}, can retry in {int(RATE_LIMIT_SECONDS - since_last_attempt)} seconds") + return json_response({ + "status": "error", + "code": "error_rate_limit", + "content": f"Rate limit on domain, retry in {int(RATE_LIMIT_SECONDS - since_last_attempt)} seconds", + }) + + RATE_LIMIT_DB[domain] = time.time() + if not validators.domain(domain): logger.info(f"Invalid request, is not in the right format (domain is : {domain})") return json_response({ @@ -128,20 +177,6 @@ async def check_http(request): "content": "an error happen while trying to get your domain, it's very likely unreachable", }) - # [x] - get ip - # [x] - get request json - # [x] - in request json get domain target - # [x] - validate domain is in correct format - # [x] - check dns that domain == ip - # [x] - if not, complain - # [x] - handle ipv6 - # [x] - if everything is ok, try to get with http - # [x] - ADD TIMEOUT - # [x] - try/catch, if everything is ok → response ok - # [x] - otherwise reponse with exception - # [x] - create error codes - # [ ] - rate limit - return json_response({"status": "ok"})