patx/micropie
added mongoKV powered SessionBackend to url_shortner example to demo a more prod ready system
Commit fda0eed · patx · 2025-12-28T18:45:26-05:00
Comments
No comments yet.
Diff
diff --git a/examples/url_shortener/main.py b/examples/url_shortener/main.py
index 3e8660a..7a9b9d1 100644
--- a/examples/url_shortener/main.py
+++ b/examples/url_shortener/main.py
@@ -4,17 +4,25 @@ from secrets import choice
from micropie import App
from mongokv import Mkv
+# Import middlewares and session backends
from middlewares.rate_limit import MongoRateLimitMiddleware
from middlewares.csrf import CSRFMiddleware
+from sessions.mongo_session import MkvSessionBackend
+# EXAMPLE KEYS/URI, in production use/generate your own and save it as an
+# environment variables, do not hard code them like these demos HINT: You
+# can use `secrets.token_urlsafe(64)` to generate your CSRF secret key
URL_ROOT = "http://localhost:8000/"
MONGO_URI = "mongodb://localhost:27017"
+DB_NAME = "shorty"
CSRF_KEY = "wzWf0CsZr3LfrgPVc9RqHFVUmyXsYT-k8hnGt41bMGU"
-db = Mkv(MONGO_URI)
+# Create an mongoKV instance using our URI
+db = Mkv(MONGO_URI, db_name=DB_NAME, collection_name="urls")
+# Our main app class
class Shorty(App):
def _generate_id(self, length: int = 8) -> str:
@@ -46,7 +54,23 @@ class Shorty(App):
return await self._render_template("index.html", request=self.request)
-app = Shorty()
-app.middlewares.append(MongoRateLimitMiddleware(mongo_uri=MONGO_URI))
-app.middlewares.append(CSRFMiddleware(app=app, secret_key=CSRF_KEY))
+app = Shorty(session_backend=MkvSessionBackend(
+ mongo_uri=MONGO_URI,
+ db_name=DB_NAME
+ )
+)
+app.middlewares.append(
+ MongoRateLimitMiddleware(
+ mongo_uri=MONGO_URI,
+ allowed_hosts=None, # don't enforce host allowlist, change in prod
+ trust_proxy_headers=False, # change in prod
+ require_cf_ray=False,
+ )
+)
+app.middlewares.append(
+ CSRFMiddleware(
+ app=app,
+ secret_key=CSRF_KEY
+ )
+)
diff --git a/examples/url_shortener/middlewares/rate_limit.py b/examples/url_shortener/middlewares/rate_limit.py
index 36bbd43..983f4a3 100644
--- a/examples/url_shortener/middlewares/rate_limit.py
+++ b/examples/url_shortener/middlewares/rate_limit.py
@@ -1,16 +1,36 @@
from __future__ import annotations
+import ipaddress
from datetime import datetime, timedelta
+from typing import Set
-from motor.motor_asyncio import AsyncIOMotorClient
+from pymongo import AsyncMongoClient, ReturnDocument
from micropie import HttpMiddleware
+def _valid_ip(value: str | None) -> str | None:
+ try:
+ return str(ipaddress.ip_address(value.strip()))
+ except Exception:
+ return None
+
+
class MongoRateLimitMiddleware(HttpMiddleware):
"""
- Global MongoDB-based rate limiter.
+ Global MongoDB-based rate limiter (Heroku + Cloudflare safe).
+
+ - One document per client IP
+ - Fixed window counter
+ - Escalating temporary blocks
+ - Permanent block based on 24h violation history
+ - Fully atomic (single DB op per request)
+ - PyMongo Async API (no Motor)
+
+ Requires:
+ - MongoDB 4.2+ (aggregation pipeline updates)
"""
+ # --- rate config ---
MAX_REQUESTS = 50
WINDOW_SECONDS = 60
@@ -23,101 +43,267 @@ class MongoRateLimitMiddleware(HttpMiddleware):
def __init__(
self,
mongo_uri: str,
- db_name: str = "rate_limit",
- collection_name: str = "list",
+ db_name: str = "vegy_security",
+ collection_name: str = "rate_limits_global",
+ *,
+ allowed_hosts: Set[str] | None = None,
+ trust_proxy_headers: bool = True,
+ require_cf_ray: bool = True,
):
- self.client = AsyncIOMotorClient(mongo_uri)
+ self.client = AsyncMongoClient(mongo_uri)
self.db = self.client[db_name]
self.collection = self.db[collection_name]
- async def before_request(self, request):
+ # Security / proxy config
+ self.allowed_hosts = allowed_hosts or set()
+ self.trust_proxy_headers = trust_proxy_headers
+ self.require_cf_ray = require_cf_ray
+
+ # ---------------------------------------------------------
+ # Real client IP resolution (for Cloudflare + Heroku or similar setups)
+ # ---------------------------------------------------------
+
+ def _client_ip(self, request) -> str:
+ headers = getattr(request, "headers", {}) or {}
+
+ # 1) Optional host allow-list (prevents origin bypass)
+ if self.allowed_hosts:
+ host = (headers.get("host") or "").split(":", 1)[0].lower()
+ if host and host not in self.allowed_hosts:
+ return "unknown"
+
+ # 2) Only trust proxy headers if allowed
+ can_trust = self.trust_proxy_headers
+ if can_trust and self.require_cf_ray:
+ can_trust = bool(headers.get("cf-ray"))
+
+ if can_trust:
+ # Cloudflare headers (best)
+ for h in ("cf-connecting-ip", "true-client-ip"):
+ ip = _valid_ip(headers.get(h))
+ if ip:
+ return ip
+
+ # Standard proxy chain
+ xff = headers.get("x-forwarded-for")
+ if isinstance(xff, str):
+ ip = _valid_ip(xff.split(",", 1)[0])
+ if ip:
+ return ip
+
+ # Fallback proxy header
+ ip = _valid_ip(headers.get("x-real-ip"))
+ if ip:
+ return ip
+
+ # 3) ASGI scope fallback (Heroku router)
client = request.scope.get("client") or ("unknown", 0)
- client_ip = client[0] or "unknown"
+ return _valid_ip(client[0]) or "unknown"
+
+ # ---------------------------------------------------------
+ # Middleware hook
+ # ---------------------------------------------------------
+
+ async def before_request(self, request):
+ client_ip = self._client_ip(request)
now = datetime.utcnow()
window_start_cutoff = now - timedelta(seconds=self.WINDOW_SECONDS)
perma_window_cutoff = now - timedelta(hours=self.PERMA_WINDOW_HOURS)
key = client_ip
- doc = await self.collection.find_one({"_id": key})
- if doc and doc.get("permanent_blocked"):
- return {
- "status_code": 403,
- "body": f"Access permanently blocked for IP {client_ip}.",
- "headers": [],
- }
+ doc = await self.collection.find_one_and_update(
+ {"_id": key},
+ [
+ # 1) Baseline fields
+ {
+ "$set": {
+ "_id": key,
+ "ip": client_ip,
+ "count": {"$ifNull": ["$count", 0]},
+ "window_start": {"$ifNull": ["$window_start", now]},
+ "violations": {"$ifNull": ["$violations", 0]},
+ "blocked_until": {"$ifNull": ["$blocked_until", None]},
+ "permanent_blocked": {"$ifNull": ["$permanent_blocked", False]},
+ "permanent_blocked_at": {"$ifNull": ["$permanent_blocked_at", None]},
+ "violation_events": {"$ifNull": ["$violation_events", []]},
+ }
+ },
- if doc:
- blocked_until = doc.get("blocked_until")
- if isinstance(blocked_until, datetime) and now < blocked_until:
- retry_after = max(0, int((blocked_until - now).total_seconds()))
- return {
- "status_code": 429,
- "body": f"Too many requests from {client_ip}. Temporarily blocked.",
- "headers": [("Retry-After", str(retry_after))],
- }
-
- if (
- not doc
- or doc.get("window_start") is None
- or doc["window_start"] < window_start_cutoff
- ):
- violations = doc.get("violations", 0) if doc else 0
-
- await self.collection.replace_one(
- {"_id": key},
+ # 2) Prune old violation events
{
- "_id": key,
- "ip": client_ip,
- "count": 1,
- "window_start": now,
- "violations": violations,
- "blocked_until": None,
- "permanent_blocked": doc.get("permanent_blocked", False) if doc else False,
- "permanent_blocked_at": doc.get("permanent_blocked_at") if doc else None,
- "violation_events": doc.get("violation_events", []) if doc else [],
+ "$set": {
+ "violation_events": {
+ "$filter": {
+ "input": "$violation_events",
+ "as": "t",
+ "cond": {"$gte": ["$$t", perma_window_cutoff]},
+ }
+ }
+ }
},
- upsert=True,
- )
- return None
- count = int(doc.get("count", 0))
+ # 3) Are we currently blocked?
+ {
+ "$set": {
+ "_blocked_now": {
+ "$or": [
+ "$permanent_blocked",
+ {
+ "$and": [
+ {"$ne": ["$blocked_until", None]},
+ {"$gt": ["$blocked_until", now]},
+ ]
+ },
+ ]
+ }
+ }
+ },
- if count >= self.MAX_REQUESTS:
- violations = int(doc.get("violations", 0)) + 1
+ # 4) Update window/count atomically (only if not blocked)
+ {
+ "$set": {
+ "_window_expired": {
+ "$cond": [
+ "$_blocked_now",
+ False,
+ {"$lt": ["$window_start", window_start_cutoff]},
+ ]
+ }
+ }
+ },
+ {
+ "$set": {
+ "window_start": {
+ "$cond": [
+ "$_blocked_now",
+ "$window_start",
+ {"$cond": ["$_window_expired", now, "$window_start"]},
+ ]
+ },
+ "count": {
+ "$cond": [
+ "$_blocked_now",
+ "$count",
+ {"$cond": ["$_window_expired", 1, {"$add": ["$count", 1]}]},
+ ]
+ },
+ }
+ },
- await self.collection.update_one(
- {"_id": key},
+ # 5) Over limit?
{
- "$set": {"violations": violations},
- "$push": {"violation_events": now},
- "$pull": {"violation_events": {"$lt": perma_window_cutoff}},
+ "$set": {
+ "_over_limit": {
+ "$and": [
+ {"$not": "$_blocked_now"},
+ {"$gt": ["$count", self.MAX_REQUESTS]},
+ ]
+ }
+ }
},
- )
- doc = await self.collection.find_one({"_id": key})
- events_last_24h = len((doc or {}).get("violation_events", []))
+ # 6) Record violation if over limit
+ {
+ "$set": {
+ "violations": {
+ "$cond": ["$_over_limit", {"$add": ["$violations", 1]}, "$violations"]
+ },
+ "violation_events": {
+ "$cond": [
+ "$_over_limit",
+ {"$concatArrays": ["$violation_events", [now]]},
+ "$violation_events",
+ ]
+ },
+ }
+ },
- update = {}
+ # 7) Temporary block escalation
+ {
+ "$set": {
+ "blocked_until": {
+ "$cond": [
+ {
+ "$and": [
+ "$_over_limit",
+ {"$gte": ["$violations", self.BLOCK_AFTER_VIOLATIONS]},
+ ]
+ },
+ now + timedelta(seconds=self.BLOCK_FOR_SECONDS),
+ "$blocked_until",
+ ]
+ }
+ }
+ },
- if violations >= self.BLOCK_AFTER_VIOLATIONS:
- update["blocked_until"] = now + timedelta(seconds=self.BLOCK_FOR_SECONDS)
+ # 8) Permanent block escalation
+ {"$set": {"_events_24h": {"$size": "$violation_events"}}},
+ {
+ "$set": {
+ "permanent_blocked": {
+ "$cond": [
+ {
+ "$and": [
+ "$_over_limit",
+ {"$gte": ["$_events_24h", self.PERMA_BLOCK_AFTER]},
+ ]
+ },
+ True,
+ "$permanent_blocked",
+ ]
+ },
+ "permanent_blocked_at": {
+ "$cond": [
+ {
+ "$and": [
+ "$_over_limit",
+ {"$gte": ["$_events_24h", self.PERMA_BLOCK_AFTER]},
+ {"$eq": ["$permanent_blocked_at", None]},
+ ]
+ },
+ now,
+ "$permanent_blocked_at",
+ ]
+ },
+ }
+ },
+
+ # 9) Cleanup temp fields
+ {"$unset": ["_blocked_now", "_window_expired", "_over_limit", "_events_24h"]},
+ ],
+ upsert=True,
+ return_document=ReturnDocument.AFTER,
+ projection={"count": 1, "blocked_until": 1, "permanent_blocked": 1},
+ )
- if events_last_24h >= self.PERMA_BLOCK_AFTER:
- update["permanent_blocked"] = True
- update["permanent_blocked_at"] = now
+ doc = doc or {}
- if update:
- await self.collection.update_one({"_id": key}, {"$set": update})
+ # --- responses ---
+ if doc.get("permanent_blocked"):
+ return {
+ "status_code": 403,
+ "body": f"Access permanently blocked for IP {client_ip}.",
+ "headers": [],
+ }
+
+ blocked_until = doc.get("blocked_until")
+ if isinstance(blocked_until, datetime) and now < blocked_until:
+ retry_after = max(0, int((blocked_until - now).total_seconds()))
+ return {
+ "status_code": 429,
+ "body": f"Too many requests from {client_ip}. Temporarily blocked.",
+ "headers": [("Retry-After", str(retry_after))],
+ }
+ if int(doc.get("count", 0)) > self.MAX_REQUESTS:
return {
"status_code": 429,
"body": f"Rate limit exceeded for IP {client_ip}.",
"headers": [],
}
- await self.collection.update_one({"_id": key}, {"$inc": {"count": 1}})
return None
async def after_request(self, request, status_code, response_body, extra_headers):
diff --git a/examples/url_shortener/sessions/__init__.py b/examples/url_shortener/sessions/__init__.py
new file mode 100644
index 0000000..ea3f22e
--- /dev/null
+++ b/examples/url_shortener/sessions/__init__.py
@@ -0,0 +1,3 @@
+from .mongo_session import MkvSessionBackend
+
+__all__ = ["MkvSessionBackend"]
diff --git a/examples/url_shortener/sessions/mongo_session.py b/examples/url_shortener/sessions/mongo_session.py
new file mode 100644
index 0000000..1b47583
--- /dev/null
+++ b/examples/url_shortener/sessions/mongo_session.py
@@ -0,0 +1,94 @@
+import time
+from typing import Any, Dict, Optional
+
+from mongokv import Mkv
+from micropie import SessionBackend
+
+
+class MkvSessionBackend(SessionBackend):
+ """
+ Session backend backed by mongokv.Mkv.
+
+ Storage schema (per session_id):
+ key = session_id
+ value = {
+ "data": { ...session dict... },
+ "expires_at": <unix_epoch_seconds>
+ }
+
+ Notes:
+ - Expiration is enforced on load (lazy cleanup).
+ - save(..., {}, 0) deletes (matches MicroPie logout behavior).
+ """
+
+ def __init__(
+ self,
+ mongo_uri: str,
+ db_name: str,
+ collection_name: str = "sessions",
+ *,
+ key_prefix: str = "sess:",
+ ) -> None:
+ self.store = Mkv(mongo_uri, db_name=db_name, collection_name=collection_name)
+ self.key_prefix = key_prefix
+
+ def _k(self, session_id: str) -> str:
+ return f"{self.key_prefix}{session_id}"
+
+ async def load(self, session_id: str) -> Dict[str, Any]:
+ if not session_id:
+ return {}
+
+ key = self._k(session_id)
+
+ try:
+ payload = await self.store.get(key)
+ except KeyError:
+ return {}
+ except Exception:
+ # If you prefer, log this instead of swallowing.
+ return {}
+
+ if not isinstance(payload, dict):
+ # Corrupt/unexpected; treat as empty and delete
+ try:
+ await self.store.remove(key)
+ except Exception:
+ pass
+ return {}
+
+ expires_at = payload.get("expires_at")
+ if isinstance(expires_at, (int, float)) and time.time() > float(expires_at):
+ # Expired: delete and return empty
+ try:
+ await self.store.remove(key)
+ except Exception:
+ pass
+ return {}
+
+ data = payload.get("data", {})
+ return data if isinstance(data, dict) else {}
+
+ async def save(self, session_id: str, data: Dict[str, Any], timeout: int) -> None:
+ if not session_id:
+ return
+
+ key = self._k(session_id)
+
+ # MicroPie uses save(session_id, {}, 0) for logout/delete
+ if not data or timeout <= 0:
+ try:
+ await self.store.remove(key)
+ except Exception:
+ pass
+ return
+
+ expires_at = time.time() + int(timeout)
+
+ payload = {
+ "data": data,
+ "expires_at": expires_at,
+ }
+
+ await self.store.set(key, payload)
+