patx/micropie
fix(rate-limit): remove conflicting $push/$pull ops, switch to scalar violation tracking
Commit ac4b6d3 · patx · 2025-12-08T22:29:38-05:00
Comments
No comments yet.
Diff
diff --git a/examples/middleware/rate_limit.py b/examples/middleware/rate_limit.py
index eb2e767..a61df36 100644
--- a/examples/middleware/rate_limit.py
+++ b/examples/middleware/rate_limit.py
@@ -1,6 +1,7 @@
from datetime import datetime, timedelta
from micropie import App, HttpMiddleware
from motor.motor_asyncio import AsyncIOMotorClient
+from pymongo.errors import PyMongoError
class MongoRateLimitMiddleware(HttpMiddleware):
@@ -17,7 +18,7 @@ class MongoRateLimitMiddleware(HttpMiddleware):
WINDOW_SECONDS = 60 # window length in seconds
BLOCK_AFTER_VIOLATIONS = 3 # how many windows exceeded before temp block
- BLOCK_FOR_SECONDS = 900 # how long to temporarily block in seconds
+ BLOCK_FOR_SECONDS = 900 # how long to temporarily block (seconds)
PERMA_WINDOW_HOURS = 24 # lookback window for permanent block
PERMA_BLOCK_AFTER = 10 # violations in window before permanent block
@@ -28,7 +29,7 @@ class MongoRateLimitMiddleware(HttpMiddleware):
db_name: str = "vegy_security",
collection_name: str = "rate_limits_global",
):
- self.client = AsyncIOMotorClient(mongo_uri)
+ self.client = motor.motor_asyncio.AsyncIOMotorClient(mongo_uri)
self.db = self.client[db_name]
self.collection = self.db[collection_name]
@@ -42,7 +43,11 @@ class MongoRateLimitMiddleware(HttpMiddleware):
key = client_ip # one document per IP
- doc = await self.collection.find_one({"_id": key})
+ try:
+ doc = await self.collection.find_one({"_id": key})
+ except PyMongoError:
+ # If Mongo is unhappy, don't take the whole app down.
+ return None
# 0. Permanent block check
if doc and doc.get("permanent_blocked"):
@@ -73,22 +78,30 @@ class MongoRateLimitMiddleware(HttpMiddleware):
or doc["window_start"] < window_start_cutoff
):
violations = doc.get("violations", 0) if doc else 0
+ permanent_blocked = doc.get("permanent_blocked", False) if doc else False
+ first_violation_at = doc.get("first_violation_at") if doc else None
+ violation_count_window = doc.get("violation_count_window", 0) if doc else 0
+
+ try:
+ await self.collection.replace_one(
+ {"_id": key},
+ {
+ "_id": key,
+ "ip": client_ip,
+ "count": 1,
+ "window_start": now,
+ "violations": violations,
+ "blocked_until": None,
+ "permanent_blocked": permanent_blocked,
+ "first_violation_at": first_violation_at,
+ "violation_count_window": violation_count_window,
+ },
+ upsert=True,
+ )
+ except PyMongoError:
+ # Soft-fail if Mongo is down
+ return None
- await self.collection.replace_one(
- {"_id": key},
- {
- "_id": key,
- "ip": client_ip,
- "count": 1,
- "window_start": now,
- "violations": violations,
- "blocked_until": None,
- "permanent_blocked": doc.get("permanent_blocked", False) if doc else False,
- # keep a small history of violation events
- "violation_events": doc.get("violation_events", []) if doc else [],
- },
- upsert=True,
- )
return None # allow request
# 3. Window active -> check count
@@ -98,40 +111,45 @@ class MongoRateLimitMiddleware(HttpMiddleware):
# Exceeded this window
violations = doc.get("violations", 0) + 1
- # Update violations + violation_events (pull old, push new)
- await self.collection.update_one(
- {"_id": key},
- {
- "$set": {
- "violations": violations,
- },
- "$push": {"violation_events": now},
- "$pull": {"violation_events": {"$lt": perma_window_cutoff}},
- },
- )
+ first_violation_at = doc.get("first_violation_at")
+ violation_count_window = doc.get("violation_count_window", 0)
- # Re-fetch to inspect updated violation_events
- doc = await self.collection.find_one({"_id": key})
- events = doc.get("violation_events", [])
- events_last_24h = len(events)
+ # Reset 24h window if outside lookback
+ if not first_violation_at or first_violation_at < perma_window_cutoff:
+ first_violation_at = now
+ violation_count_window = 1
+ else:
+ violation_count_window += 1
- update = {}
+ update_fields = {
+ "violations": violations,
+ "first_violation_at": first_violation_at,
+ "violation_count_window": violation_count_window,
+ }
- # Escalate to temporary block if too many violations overall
+ # Temporary block if too many violations overall
if violations >= self.BLOCK_AFTER_VIOLATIONS:
- blocked_until = now + timedelta(seconds=self.BLOCK_FOR_SECONDS)
- update["blocked_until"] = blocked_until
+ update_fields["blocked_until"] = now + timedelta(
+ seconds=self.BLOCK_FOR_SECONDS
+ )
# Permanent block if too many violations in last 24 hours
- if events_last_24h >= self.PERMA_BLOCK_AFTER:
- update["permanent_blocked"] = True
- update["permanent_blocked_at"] = now
+ if violation_count_window >= self.PERMA_BLOCK_AFTER:
+ update_fields["permanent_blocked"] = True
+ update_fields["permanent_blocked_at"] = now
- if update:
+ try:
await self.collection.update_one(
{"_id": key},
- {"$set": update},
+ {"$set": update_fields},
)
+ except PyMongoError:
+ # If the write fails, still return 429 so the attacker doesn't get through.
+ return {
+ "status_code": 429,
+ "body": f"Rate limit exceeded for IP {client_ip}.",
+ "headers": [],
+ }
return {
"status_code": 429,
@@ -140,14 +158,19 @@ class MongoRateLimitMiddleware(HttpMiddleware):
}
# 4. Still within limit -> increment count
- await self.collection.update_one(
- {"_id": key},
- {"$inc": {"count": 1}},
- )
+ try:
+ await self.collection.update_one(
+ {"_id": key},
+ {"$inc": {"count": 1}},
+ )
+ except PyMongoError:
+ # Soft-fail on logging error
+ return None
return None # allow request
async def after_request(self, request, status_code, response_body, extra_headers):
+ # No-op for now
pass