fix(rate-limit): remove conflicting $push/$pull ops, switch to scalar violation tracking

Commit ac4b6d3 · patx · 2025-12-08T22:29:38-05:00

Changeset
ac4b6d3bdc9c3dc8f86619de455f6c13bd9ea4cf
Parents
f800cc39f637f9cf22c98b22bd29983d97095911

View source at this commit

Comments

No comments yet.

Log in to comment

Diff

diff --git a/examples/middleware/rate_limit.py b/examples/middleware/rate_limit.py
index eb2e767..a61df36 100644
--- a/examples/middleware/rate_limit.py
+++ b/examples/middleware/rate_limit.py
@@ -1,6 +1,7 @@
 from datetime import datetime, timedelta
 from micropie import App, HttpMiddleware
 from motor.motor_asyncio import AsyncIOMotorClient
+from pymongo.errors import PyMongoError
 
 
 class MongoRateLimitMiddleware(HttpMiddleware):
@@ -17,7 +18,7 @@ class MongoRateLimitMiddleware(HttpMiddleware):
     WINDOW_SECONDS = 60            # window length in seconds
 
     BLOCK_AFTER_VIOLATIONS = 3     # how many windows exceeded before temp block
-    BLOCK_FOR_SECONDS = 900       # how long to temporarily block in seconds
+    BLOCK_FOR_SECONDS = 900        # how long to temporarily block (seconds)
 
     PERMA_WINDOW_HOURS = 24        # lookback window for permanent block
     PERMA_BLOCK_AFTER = 10         # violations in window before permanent block
@@ -28,7 +29,7 @@ class MongoRateLimitMiddleware(HttpMiddleware):
         db_name: str = "vegy_security",
         collection_name: str = "rate_limits_global",
     ):
-        self.client = AsyncIOMotorClient(mongo_uri)
+        self.client = motor.motor_asyncio.AsyncIOMotorClient(mongo_uri)
         self.db = self.client[db_name]
         self.collection = self.db[collection_name]
 
@@ -42,7 +43,11 @@ class MongoRateLimitMiddleware(HttpMiddleware):
 
         key = client_ip  # one document per IP
 
-        doc = await self.collection.find_one({"_id": key})
+        try:
+            doc = await self.collection.find_one({"_id": key})
+        except PyMongoError:
+            # If Mongo is unhappy, don't take the whole app down.
+            return None
 
         # 0. Permanent block check
         if doc and doc.get("permanent_blocked"):
@@ -73,22 +78,30 @@ class MongoRateLimitMiddleware(HttpMiddleware):
             or doc["window_start"] < window_start_cutoff
         ):
             violations = doc.get("violations", 0) if doc else 0
+            permanent_blocked = doc.get("permanent_blocked", False) if doc else False
+            first_violation_at = doc.get("first_violation_at") if doc else None
+            violation_count_window = doc.get("violation_count_window", 0) if doc else 0
+
+            try:
+                await self.collection.replace_one(
+                    {"_id": key},
+                    {
+                        "_id": key,
+                        "ip": client_ip,
+                        "count": 1,
+                        "window_start": now,
+                        "violations": violations,
+                        "blocked_until": None,
+                        "permanent_blocked": permanent_blocked,
+                        "first_violation_at": first_violation_at,
+                        "violation_count_window": violation_count_window,
+                    },
+                    upsert=True,
+                )
+            except PyMongoError:
+                # Soft-fail if Mongo is down
+                return None
 
-            await self.collection.replace_one(
-                {"_id": key},
-                {
-                    "_id": key,
-                    "ip": client_ip,
-                    "count": 1,
-                    "window_start": now,
-                    "violations": violations,
-                    "blocked_until": None,
-                    "permanent_blocked": doc.get("permanent_blocked", False) if doc else False,
-                    # keep a small history of violation events
-                    "violation_events": doc.get("violation_events", []) if doc else [],
-                },
-                upsert=True,
-            )
             return None  # allow request
 
         # 3. Window active -> check count
@@ -98,40 +111,45 @@ class MongoRateLimitMiddleware(HttpMiddleware):
             # Exceeded this window
             violations = doc.get("violations", 0) + 1
 
-            # Update violations + violation_events (pull old, push new)
-            await self.collection.update_one(
-                {"_id": key},
-                {
-                    "$set": {
-                        "violations": violations,
-                    },
-                    "$push": {"violation_events": now},
-                    "$pull": {"violation_events": {"$lt": perma_window_cutoff}},
-                },
-            )
+            first_violation_at = doc.get("first_violation_at")
+            violation_count_window = doc.get("violation_count_window", 0)
 
-            # Re-fetch to inspect updated violation_events
-            doc = await self.collection.find_one({"_id": key})
-            events = doc.get("violation_events", [])
-            events_last_24h = len(events)
+            # Reset 24h window if outside lookback
+            if not first_violation_at or first_violation_at < perma_window_cutoff:
+                first_violation_at = now
+                violation_count_window = 1
+            else:
+                violation_count_window += 1
 
-            update = {}
+            update_fields = {
+                "violations": violations,
+                "first_violation_at": first_violation_at,
+                "violation_count_window": violation_count_window,
+            }
 
-            # Escalate to temporary block if too many violations overall
+            # Temporary block if too many violations overall
             if violations >= self.BLOCK_AFTER_VIOLATIONS:
-                blocked_until = now + timedelta(seconds=self.BLOCK_FOR_SECONDS)
-                update["blocked_until"] = blocked_until
+                update_fields["blocked_until"] = now + timedelta(
+                    seconds=self.BLOCK_FOR_SECONDS
+                )
 
             # Permanent block if too many violations in last 24 hours
-            if events_last_24h >= self.PERMA_BLOCK_AFTER:
-                update["permanent_blocked"] = True
-                update["permanent_blocked_at"] = now
+            if violation_count_window >= self.PERMA_BLOCK_AFTER:
+                update_fields["permanent_blocked"] = True
+                update_fields["permanent_blocked_at"] = now
 
-            if update:
+            try:
                 await self.collection.update_one(
                     {"_id": key},
-                    {"$set": update},
+                    {"$set": update_fields},
                 )
+            except PyMongoError:
+                # If the write fails, still return 429 so the attacker doesn't get through.
+                return {
+                    "status_code": 429,
+                    "body": f"Rate limit exceeded for IP {client_ip}.",
+                    "headers": [],
+                }
 
             return {
                 "status_code": 429,
@@ -140,14 +158,19 @@ class MongoRateLimitMiddleware(HttpMiddleware):
             }
 
         # 4. Still within limit -> increment count
-        await self.collection.update_one(
-            {"_id": key},
-            {"$inc": {"count": 1}},
-        )
+        try:
+            await self.collection.update_one(
+                {"_id": key},
+                {"$inc": {"count": 1}},
+            )
+        except PyMongoError:
+            # Soft-fail on logging error
+            return None
 
         return None  # allow request
 
     async def after_request(self, request, status_code, response_body, extra_headers):
+        # No-op for now
         pass