from __future__ import annotations

import ipaddress
from datetime import datetime, timedelta, timezone
from typing import Set

from micropie import HttpMiddleware
from pymongo import AsyncMongoClient, ReturnDocument
def _valid_ip(value: str | None) -> str | None:
try:
return str(ipaddress.ip_address(value.strip()))
except Exception:
return None
class MongoRateLimitMiddleware(HttpMiddleware):
    """
    Global MongoDB-based rate limiter (Heroku + Cloudflare safe).

    - One document per client IP
    - Fixed window counter
    - Escalating temporary blocks
    - Permanent block based on 24h violation history
    - One ``find_one_and_update`` (pipeline update) per request
    - PyMongo Async API (no Motor)

    Requires:
    - MongoDB 4.2+ (aggregation pipeline updates)

    Note: every request whose client IP cannot be resolved, or whose Host
    header is rejected by ``allowed_hosts``, is counted against the single
    shared ``"unknown"`` bucket.
    """

    # --- rate config ---
    MAX_REQUESTS = 50  # requests allowed per window
    WINDOW_SECONDS = 60  # fixed window length
    BLOCK_AFTER_VIOLATIONS = 3  # cumulative violations before a temp block
    BLOCK_FOR_SECONDS = 900  # temp block duration
    PERMA_WINDOW_HOURS = 24  # look-back period for violation history
    PERMA_BLOCK_AFTER = 10  # violations in look-back => permanent block

    def __init__(
        self,
        mongo_uri: str,
        db_name: str,
        collection_name: str = "rate_limits_global",
        *,
        allowed_hosts: Set[str] | None = None,
        trust_proxy_headers: bool = True,
        require_cf_ray: bool = True,
    ):
        """Create the middleware and its MongoDB client.

        Args:
            mongo_uri: Connection string passed to ``AsyncMongoClient``.
            db_name: Database holding the rate-limit collection.
            collection_name: Collection with one document per client IP.
            allowed_hosts: Optional Host allow-list (matched
                case-insensitively, port ignored). Empty/None disables it.
            trust_proxy_headers: Whether proxy headers may be used to
                resolve the client IP at all.
            require_cf_ray: When trusting proxy headers, additionally
                require a ``cf-ray`` header to be present.
        """
        self.client = AsyncMongoClient(mongo_uri)
        self.db = self.client[db_name]
        self.collection = self.db[collection_name]

        # Security / proxy config
        # _client_ip() lower-cases the Host header before the membership
        # test, so the allow-list is lower-cased once here; otherwise a
        # mixed-case entry could never match.
        self.allowed_hosts = (
            {host.lower() for host in allowed_hosts} if allowed_hosts else set()
        )
        self.trust_proxy_headers = trust_proxy_headers
        self.require_cf_ray = require_cf_ray

    # ---------------------------------------------------------
    # Real client IP resolution (for Cloudflare + Heroku or similar setups)
    # ---------------------------------------------------------
    def _client_ip(self, request) -> str:
        """Resolve the client IP for *request*, or ``"unknown"``.

        Resolution order: host allow-list check, then (if trusted) the
        Cloudflare headers, the first ``x-forwarded-for`` entry and
        ``x-real-ip``, and finally the ASGI ``scope["client"]`` address.
        """
        headers = getattr(request, "headers", {}) or {}

        # 1) Optional host allow-list (prevents origin bypass)
        if self.allowed_hosts:
            host = (headers.get("host") or "").split(":", 1)[0].lower()
            # A missing/empty Host header is deliberately not rejected here.
            if host and host not in self.allowed_hosts:
                return "unknown"

        # 2) Only trust proxy headers if allowed
        can_trust = self.trust_proxy_headers
        if can_trust and self.require_cf_ray:
            can_trust = bool(headers.get("cf-ray"))

        if can_trust:
            # Cloudflare headers (best)
            for header_name in ("cf-connecting-ip", "true-client-ip"):
                ip = _valid_ip(headers.get(header_name))
                if ip:
                    return ip

            # Standard proxy chain: use the first (left-most) entry
            xff = headers.get("x-forwarded-for")
            if isinstance(xff, str):
                ip = _valid_ip(xff.split(",", 1)[0])
                if ip:
                    return ip

            # Fallback proxy header
            ip = _valid_ip(headers.get("x-real-ip"))
            if ip:
                return ip

        # 3) ASGI scope fallback (Heroku router)
        client = request.scope.get("client") or ("unknown", 0)
        return _valid_ip(client[0]) or "unknown"

    # ---------------------------------------------------------
    # Update pipeline
    # ---------------------------------------------------------
    def _build_pipeline(self, key: str, now: datetime) -> list[dict]:
        """Build the aggregation-pipeline update applied on every request.

        Stage order matters: each stage reads fields written by earlier
        ones. Fields prefixed with ``_`` are scratch values removed by the
        final ``$unset`` stage.

        Args:
            key: Document ``_id`` (the resolved client IP).
            now: Current time as a naive UTC ``datetime``.
        """
        window_start_cutoff = now - timedelta(seconds=self.WINDOW_SECONDS)
        perma_window_cutoff = now - timedelta(hours=self.PERMA_WINDOW_HOURS)
        temp_block_until = now + timedelta(seconds=self.BLOCK_FOR_SECONDS)

        # Shared by both permanent-block fields in stage 8.
        perma_conditions = [
            "$_over_limit",
            {"$gte": ["$_events_24h", self.PERMA_BLOCK_AFTER]},
        ]

        return [
            # 1) Baseline fields (fills defaults on first sight / upsert)
            {
                "$set": {
                    "_id": key,
                    "ip": key,
                    "count": {"$ifNull": ["$count", 0]},
                    "window_start": {"$ifNull": ["$window_start", now]},
                    "violations": {"$ifNull": ["$violations", 0]},
                    "blocked_until": {"$ifNull": ["$blocked_until", None]},
                    "permanent_blocked": {
                        "$ifNull": ["$permanent_blocked", False]
                    },
                    "permanent_blocked_at": {
                        "$ifNull": ["$permanent_blocked_at", None]
                    },
                    "violation_events": {
                        "$ifNull": ["$violation_events", []]
                    },
                }
            },
            # 2) Prune old violation events
            {
                "$set": {
                    "violation_events": {
                        "$filter": {
                            "input": "$violation_events",
                            "as": "t",
                            "cond": {"$gte": ["$$t", perma_window_cutoff]},
                        }
                    }
                }
            },
            # 3) Are we currently blocked?
            {
                "$set": {
                    "_blocked_now": {
                        "$or": [
                            "$permanent_blocked",
                            {
                                "$and": [
                                    {"$ne": ["$blocked_until", None]},
                                    {"$gt": ["$blocked_until", now]},
                                ]
                            },
                        ]
                    }
                }
            },
            # 4) Update window/count atomically (only if not blocked)
            {
                "$set": {
                    "_window_expired": {
                        "$cond": [
                            "$_blocked_now",
                            False,
                            {"$lt": ["$window_start", window_start_cutoff]},
                        ]
                    }
                }
            },
            {
                "$set": {
                    "window_start": {
                        "$cond": [
                            "$_blocked_now",
                            "$window_start",
                            {
                                "$cond": [
                                    "$_window_expired",
                                    now,
                                    "$window_start",
                                ]
                            },
                        ]
                    },
                    "count": {
                        "$cond": [
                            "$_blocked_now",
                            "$count",
                            {
                                "$cond": [
                                    "$_window_expired",
                                    1,
                                    {"$add": ["$count", 1]},
                                ]
                            },
                        ]
                    },
                }
            },
            # 5) Over limit?
            {
                "$set": {
                    "_over_limit": {
                        "$and": [
                            {"$not": "$_blocked_now"},
                            {"$gt": ["$count", self.MAX_REQUESTS]},
                        ]
                    }
                }
            },
            # 6) Record violation if over limit
            {
                "$set": {
                    "violations": {
                        "$cond": [
                            "$_over_limit",
                            {"$add": ["$violations", 1]},
                            "$violations",
                        ]
                    },
                    "violation_events": {
                        "$cond": [
                            "$_over_limit",
                            {"$concatArrays": ["$violation_events", [now]]},
                            "$violation_events",
                        ]
                    },
                }
            },
            # 7) Temporary block escalation
            {
                "$set": {
                    "blocked_until": {
                        "$cond": [
                            {
                                "$and": [
                                    "$_over_limit",
                                    {
                                        "$gte": [
                                            "$violations",
                                            self.BLOCK_AFTER_VIOLATIONS,
                                        ]
                                    },
                                ]
                            },
                            temp_block_until,
                            "$blocked_until",
                        ]
                    }
                }
            },
            # 8) Permanent block escalation
            {"$set": {"_events_24h": {"$size": "$violation_events"}}},
            {
                "$set": {
                    "permanent_blocked": {
                        "$cond": [
                            {"$and": perma_conditions},
                            True,
                            "$permanent_blocked",
                        ]
                    },
                    # Only stamp the time of the first permanent block.
                    "permanent_blocked_at": {
                        "$cond": [
                            {
                                "$and": [
                                    *perma_conditions,
                                    {"$eq": ["$permanent_blocked_at", None]},
                                ]
                            },
                            now,
                            "$permanent_blocked_at",
                        ]
                    },
                }
            },
            # 9) Cleanup temp fields
            {
                "$unset": [
                    "_blocked_now",
                    "_window_expired",
                    "_over_limit",
                    "_events_24h",
                ]
            },
        ]

    # ---------------------------------------------------------
    # Middleware hook
    # ---------------------------------------------------------
    async def before_request(self, request):
        """Count the request and decide whether it may proceed.

        Returns:
            ``None`` when the request is within limits; otherwise a dict
            with ``status_code`` (403 for a permanent block, 429 for a
            temporary block or an exceeded window), ``body`` and
            ``headers``.
        """
        client_ip = self._client_ip(request)

        # Naive UTC, without the deprecated datetime.utcnow(). Kept naive
        # because it is compared below with datetimes read back from
        # MongoDB (NOTE(review): assumes the client is not tz_aware, which
        # matches how it is constructed in __init__).
        now = datetime.now(timezone.utc).replace(tzinfo=None)

        doc = await self.collection.find_one_and_update(
            {"_id": client_ip},
            self._build_pipeline(client_ip, now),
            upsert=True,
            return_document=ReturnDocument.AFTER,
            projection={"count": 1, "blocked_until": 1, "permanent_blocked": 1},
        )
        doc = doc or {}

        # --- responses ---
        if doc.get("permanent_blocked"):
            return {
                "status_code": 403,
                "body": f"Access permanently blocked for IP {client_ip}.",
                "headers": [],
            }

        blocked_until = doc.get("blocked_until")
        if isinstance(blocked_until, datetime) and now < blocked_until:
            retry_after = max(0, int((blocked_until - now).total_seconds()))
            return {
                "status_code": 429,
                "body": f"Too many requests from {client_ip}. Temporarily blocked.",
                "headers": [("Retry-After", str(retry_after))],
            }

        if int(doc.get("count", 0)) > self.MAX_REQUESTS:
            return {
                "status_code": 429,
                "body": f"Rate limit exceeded for IP {client_ip}.",
                "headers": [],
            }

        return None

    async def after_request(self, request, status_code, response_body, extra_headers):
        """No post-processing; the response is left untouched."""
        return None