diff --git a/server.py b/server.py index bc89a05..5e2d2ee 100644 --- a/server.py +++ b/server.py @@ -17,15 +17,19 @@ app = Sanic() RATE_LIMIT_DB = {} # to prevent DDoS or bounce attack attempt or something like that -RATE_LIMIT_SECONDS = 5 - +# Can't do more than 10 requests in a 300-seconds window +RATE_LIMIT_SECONDS = 300 +RATE_LIMIT_NB_REQUESTS = 10 def clear_rate_limit_db(now): to_delete = [] "Remove too old rate limit values" - for key, value in RATE_LIMIT_DB.items(): - if now - value > RATE_LIMIT_SECONDS: + for key, times in RATE_LIMIT_DB.items(): + # Remove values older RATE_LIMIT_SECONDS + RATE_LIMIT_DB[key] = [t for t in times if now - t < RATE_LIMIT_SECONDS] + # If list is empty, remove the key + if RATE_LIMIT_DB[key] == []: # a dictionnary can't be modified during iteration so delegate this # operation to_delete.append(key) @@ -36,17 +40,21 @@ def clear_rate_limit_db(now): def check_rate_limit(key, now): - if key in RATE_LIMIT_DB: - since_last_attempt = now - RATE_LIMIT_DB[key] - if since_last_attempt < RATE_LIMIT_SECONDS: - logger.info(f"Rate limit reached for {key}, can retry in {int(RATE_LIMIT_SECONDS - since_last_attempt)} seconds") - return json_response({ - "status": "error", - "code": "error_rate_limit", - "content": f"Rate limit reached for this domain or ip, retry in {int(RATE_LIMIT_SECONDS - since_last_attempt)} seconds", - }, status=400) + # If there are more recent attempts than allowed + if key in RATE_LIMIT_DB and len(RATE_LIMIT_DB[key]) > RATE_LIMIT_NB_REQUESTS: + oldest_attempt = RATE_LIMIT_DB[key][0] + logger.info(f"Rate limit reached for {key}, can retry in {int(RATE_LIMIT_SECONDS - now + oldest_attempt)} seconds") + return json_response({ + "status": "error", + "code": "error_rate_limit", + "content": f"Rate limit reached for this domain or ip, retry in {int(RATE_LIMIT_SECONDS - now + oldest_attempt)} seconds", + }, status=400) - RATE_LIMIT_DB[key] = time.time() + # In any case, add this attempt to the DB + if key not in RATE_LIMIT_DB: + RATE_LIMIT_DB[key] = [now] + else: + RATE_LIMIT_DB[key].append(now) async def check_port_is_open(ip, port):