added mongoKV powered SessionBackend to url_shortner example to demo a more prod ready system

Commit fda0eed · patx · 2025-12-28T18:45:26-05:00

Changeset
fda0eedb6ad6ea2f9f685f1ae778d48cc9bc9d30
Parents
c29cf6b95e6ad28b8ffc642802ae6794f726c6d5

View source at this commit

Comments

No comments yet.

Log in to comment

Diff

diff --git a/examples/url_shortener/main.py b/examples/url_shortener/main.py
index 3e8660a..7a9b9d1 100644
--- a/examples/url_shortener/main.py
+++ b/examples/url_shortener/main.py
@@ -4,17 +4,25 @@ from secrets import choice
 from micropie import App
 from mongokv import Mkv
 
+# Import middlewares and session backends
 from middlewares.rate_limit import MongoRateLimitMiddleware
 from middlewares.csrf import CSRFMiddleware
+from sessions.mongo_session import MkvSessionBackend
 
 
+# EXAMPLE KEYS/URI, in production use/generate your own and save it as an 
+# environment variables, do not hard code them like these demos HINT: You 
+# can use `secrets.token_urlsafe(64)` to generate your CSRF secret key
 URL_ROOT = "http://localhost:8000/"
 MONGO_URI = "mongodb://localhost:27017"
+DB_NAME = "shorty"
 CSRF_KEY = "wzWf0CsZr3LfrgPVc9RqHFVUmyXsYT-k8hnGt41bMGU"
 
-db = Mkv(MONGO_URI)
+# Create an mongoKV instance using our URI
+db = Mkv(MONGO_URI, db_name=DB_NAME, collection_name="urls")
 
 
+# Our main app class
 class Shorty(App):
 
     def _generate_id(self, length: int = 8) -> str:
@@ -46,7 +54,23 @@ class Shorty(App):
         return await self._render_template("index.html", request=self.request)
 
 
-app = Shorty()
-app.middlewares.append(MongoRateLimitMiddleware(mongo_uri=MONGO_URI))
-app.middlewares.append(CSRFMiddleware(app=app, secret_key=CSRF_KEY))
+app = Shorty(session_backend=MkvSessionBackend(
+    mongo_uri=MONGO_URI, 
+    db_name=DB_NAME
+    )
+)
+app.middlewares.append(
+    MongoRateLimitMiddleware(
+        mongo_uri=MONGO_URI,
+        allowed_hosts=None,          # don't enforce host allowlist, change in prod
+        trust_proxy_headers=False,   # change in prod
+        require_cf_ray=False,
+    )
+)
+app.middlewares.append(
+    CSRFMiddleware(
+        app=app,
+        secret_key=CSRF_KEY
+    )
+)
 
diff --git a/examples/url_shortener/middlewares/rate_limit.py b/examples/url_shortener/middlewares/rate_limit.py
index 36bbd43..983f4a3 100644
--- a/examples/url_shortener/middlewares/rate_limit.py
+++ b/examples/url_shortener/middlewares/rate_limit.py
@@ -1,16 +1,36 @@
 from __future__ import annotations
 
+import ipaddress
 from datetime import datetime, timedelta
+from typing import Set
 
-from motor.motor_asyncio import AsyncIOMotorClient
+from pymongo import AsyncMongoClient, ReturnDocument
 from micropie import HttpMiddleware
 
 
+def _valid_ip(value: str | None) -> str | None:
+    try:
+        return str(ipaddress.ip_address(value.strip()))
+    except Exception:
+        return None
+
+
 class MongoRateLimitMiddleware(HttpMiddleware):
     """
-    Global MongoDB-based rate limiter.
+    Global MongoDB-based rate limiter (Heroku + Cloudflare safe).
+
+    - One document per client IP
+    - Fixed window counter
+    - Escalating temporary blocks
+    - Permanent block based on 24h violation history
+    - Fully atomic (single DB op per request)
+    - PyMongo Async API (no Motor)
+
+    Requires:
+    - MongoDB 4.2+ (aggregation pipeline updates)
     """
 
+    # --- rate config ---
     MAX_REQUESTS = 50
     WINDOW_SECONDS = 60
 
@@ -23,101 +43,267 @@ class MongoRateLimitMiddleware(HttpMiddleware):
     def __init__(
         self,
         mongo_uri: str,
-        db_name: str = "rate_limit",
-        collection_name: str = "list",
+        db_name: str = "vegy_security",
+        collection_name: str = "rate_limits_global",
+        *,
+        allowed_hosts: Set[str] | None = None,
+        trust_proxy_headers: bool = True,
+        require_cf_ray: bool = True,
     ):
-        self.client = AsyncIOMotorClient(mongo_uri)
+        self.client = AsyncMongoClient(mongo_uri)
         self.db = self.client[db_name]
         self.collection = self.db[collection_name]
 
-    async def before_request(self, request):
+        # Security / proxy config
+        self.allowed_hosts = allowed_hosts or set()
+        self.trust_proxy_headers = trust_proxy_headers
+        self.require_cf_ray = require_cf_ray
+
+    # ---------------------------------------------------------
+    # Real client IP resolution (for Cloudflare + Heroku or similar setups)
+    # ---------------------------------------------------------
+
+    def _client_ip(self, request) -> str:
+        headers = getattr(request, "headers", {}) or {}
+
+        # 1) Optional host allow-list (prevents origin bypass)
+        if self.allowed_hosts:
+            host = (headers.get("host") or "").split(":", 1)[0].lower()
+            if host and host not in self.allowed_hosts:
+                return "unknown"
+
+        # 2) Only trust proxy headers if allowed
+        can_trust = self.trust_proxy_headers
+        if can_trust and self.require_cf_ray:
+            can_trust = bool(headers.get("cf-ray"))
+
+        if can_trust:
+            # Cloudflare headers (best)
+            for h in ("cf-connecting-ip", "true-client-ip"):
+                ip = _valid_ip(headers.get(h))
+                if ip:
+                    return ip
+
+            # Standard proxy chain
+            xff = headers.get("x-forwarded-for")
+            if isinstance(xff, str):
+                ip = _valid_ip(xff.split(",", 1)[0])
+                if ip:
+                    return ip
+
+            # Fallback proxy header
+            ip = _valid_ip(headers.get("x-real-ip"))
+            if ip:
+                return ip
+
+        # 3) ASGI scope fallback (Heroku router)
         client = request.scope.get("client") or ("unknown", 0)
-        client_ip = client[0] or "unknown"
+        return _valid_ip(client[0]) or "unknown"
+
+    # ---------------------------------------------------------
+    # Middleware hook
+    # ---------------------------------------------------------
+
+    async def before_request(self, request):
+        client_ip = self._client_ip(request)
         now = datetime.utcnow()
 
         window_start_cutoff = now - timedelta(seconds=self.WINDOW_SECONDS)
         perma_window_cutoff = now - timedelta(hours=self.PERMA_WINDOW_HOURS)
 
         key = client_ip
-        doc = await self.collection.find_one({"_id": key})
 
-        if doc and doc.get("permanent_blocked"):
-            return {
-                "status_code": 403,
-                "body": f"Access permanently blocked for IP {client_ip}.",
-                "headers": [],
-            }
+        doc = await self.collection.find_one_and_update(
+            {"_id": key},
+            [
+                # 1) Baseline fields
+                {
+                    "$set": {
+                        "_id": key,
+                        "ip": client_ip,
+                        "count": {"$ifNull": ["$count", 0]},
+                        "window_start": {"$ifNull": ["$window_start", now]},
+                        "violations": {"$ifNull": ["$violations", 0]},
+                        "blocked_until": {"$ifNull": ["$blocked_until", None]},
+                        "permanent_blocked": {"$ifNull": ["$permanent_blocked", False]},
+                        "permanent_blocked_at": {"$ifNull": ["$permanent_blocked_at", None]},
+                        "violation_events": {"$ifNull": ["$violation_events", []]},
+                    }
+                },
 
-        if doc:
-            blocked_until = doc.get("blocked_until")
-            if isinstance(blocked_until, datetime) and now < blocked_until:
-                retry_after = max(0, int((blocked_until - now).total_seconds()))
-                return {
-                    "status_code": 429,
-                    "body": f"Too many requests from {client_ip}. Temporarily blocked.",
-                    "headers": [("Retry-After", str(retry_after))],
-                }
-
-        if (
-            not doc
-            or doc.get("window_start") is None
-            or doc["window_start"] < window_start_cutoff
-        ):
-            violations = doc.get("violations", 0) if doc else 0
-
-            await self.collection.replace_one(
-                {"_id": key},
+                # 2) Prune old violation events
                 {
-                    "_id": key,
-                    "ip": client_ip,
-                    "count": 1,
-                    "window_start": now,
-                    "violations": violations,
-                    "blocked_until": None,
-                    "permanent_blocked": doc.get("permanent_blocked", False) if doc else False,
-                    "permanent_blocked_at": doc.get("permanent_blocked_at") if doc else None,
-                    "violation_events": doc.get("violation_events", []) if doc else [],
+                    "$set": {
+                        "violation_events": {
+                            "$filter": {
+                                "input": "$violation_events",
+                                "as": "t",
+                                "cond": {"$gte": ["$$t", perma_window_cutoff]},
+                            }
+                        }
+                    }
                 },
-                upsert=True,
-            )
-            return None
 
-        count = int(doc.get("count", 0))
+                # 3) Are we currently blocked?
+                {
+                    "$set": {
+                        "_blocked_now": {
+                            "$or": [
+                                "$permanent_blocked",
+                                {
+                                    "$and": [
+                                        {"$ne": ["$blocked_until", None]},
+                                        {"$gt": ["$blocked_until", now]},
+                                    ]
+                                },
+                            ]
+                        }
+                    }
+                },
 
-        if count >= self.MAX_REQUESTS:
-            violations = int(doc.get("violations", 0)) + 1
+                # 4) Update window/count atomically (only if not blocked)
+                {
+                    "$set": {
+                        "_window_expired": {
+                            "$cond": [
+                                "$_blocked_now",
+                                False,
+                                {"$lt": ["$window_start", window_start_cutoff]},
+                            ]
+                        }
+                    }
+                },
+                {
+                    "$set": {
+                        "window_start": {
+                            "$cond": [
+                                "$_blocked_now",
+                                "$window_start",
+                                {"$cond": ["$_window_expired", now, "$window_start"]},
+                            ]
+                        },
+                        "count": {
+                            "$cond": [
+                                "$_blocked_now",
+                                "$count",
+                                {"$cond": ["$_window_expired", 1, {"$add": ["$count", 1]}]},
+                            ]
+                        },
+                    }
+                },
 
-            await self.collection.update_one(
-                {"_id": key},
+                # 5) Over limit?
                 {
-                    "$set": {"violations": violations},
-                    "$push": {"violation_events": now},
-                    "$pull": {"violation_events": {"$lt": perma_window_cutoff}},
+                    "$set": {
+                        "_over_limit": {
+                            "$and": [
+                                {"$not": "$_blocked_now"},
+                                {"$gt": ["$count", self.MAX_REQUESTS]},
+                            ]
+                        }
+                    }
                 },
-            )
 
-            doc = await self.collection.find_one({"_id": key})
-            events_last_24h = len((doc or {}).get("violation_events", []))
+                # 6) Record violation if over limit
+                {
+                    "$set": {
+                        "violations": {
+                            "$cond": ["$_over_limit", {"$add": ["$violations", 1]}, "$violations"]
+                        },
+                        "violation_events": {
+                            "$cond": [
+                                "$_over_limit",
+                                {"$concatArrays": ["$violation_events", [now]]},
+                                "$violation_events",
+                            ]
+                        },
+                    }
+                },
 
-            update = {}
+                # 7) Temporary block escalation
+                {
+                    "$set": {
+                        "blocked_until": {
+                            "$cond": [
+                                {
+                                    "$and": [
+                                        "$_over_limit",
+                                        {"$gte": ["$violations", self.BLOCK_AFTER_VIOLATIONS]},
+                                    ]
+                                },
+                                now + timedelta(seconds=self.BLOCK_FOR_SECONDS),
+                                "$blocked_until",
+                            ]
+                        }
+                    }
+                },
 
-            if violations >= self.BLOCK_AFTER_VIOLATIONS:
-                update["blocked_until"] = now + timedelta(seconds=self.BLOCK_FOR_SECONDS)
+                # 8) Permanent block escalation
+                {"$set": {"_events_24h": {"$size": "$violation_events"}}},
+                {
+                    "$set": {
+                        "permanent_blocked": {
+                            "$cond": [
+                                {
+                                    "$and": [
+                                        "$_over_limit",
+                                        {"$gte": ["$_events_24h", self.PERMA_BLOCK_AFTER]},
+                                    ]
+                                },
+                                True,
+                                "$permanent_blocked",
+                            ]
+                        },
+                        "permanent_blocked_at": {
+                            "$cond": [
+                                {
+                                    "$and": [
+                                        "$_over_limit",
+                                        {"$gte": ["$_events_24h", self.PERMA_BLOCK_AFTER]},
+                                        {"$eq": ["$permanent_blocked_at", None]},
+                                    ]
+                                },
+                                now,
+                                "$permanent_blocked_at",
+                            ]
+                        },
+                    }
+                },
+
+                # 9) Cleanup temp fields
+                {"$unset": ["_blocked_now", "_window_expired", "_over_limit", "_events_24h"]},
+            ],
+            upsert=True,
+            return_document=ReturnDocument.AFTER,
+            projection={"count": 1, "blocked_until": 1, "permanent_blocked": 1},
+        )
 
-            if events_last_24h >= self.PERMA_BLOCK_AFTER:
-                update["permanent_blocked"] = True
-                update["permanent_blocked_at"] = now
+        doc = doc or {}
 
-            if update:
-                await self.collection.update_one({"_id": key}, {"$set": update})
+        # --- responses ---
+        if doc.get("permanent_blocked"):
+            return {
+                "status_code": 403,
+                "body": f"Access permanently blocked for IP {client_ip}.",
+                "headers": [],
+            }
+
+        blocked_until = doc.get("blocked_until")
+        if isinstance(blocked_until, datetime) and now < blocked_until:
+            retry_after = max(0, int((blocked_until - now).total_seconds()))
+            return {
+                "status_code": 429,
+                "body": f"Too many requests from {client_ip}. Temporarily blocked.",
+                "headers": [("Retry-After", str(retry_after))],
+            }
 
+        if int(doc.get("count", 0)) > self.MAX_REQUESTS:
             return {
                 "status_code": 429,
                 "body": f"Rate limit exceeded for IP {client_ip}.",
                 "headers": [],
             }
 
-        await self.collection.update_one({"_id": key}, {"$inc": {"count": 1}})
         return None
 
     async def after_request(self, request, status_code, response_body, extra_headers):
diff --git a/examples/url_shortener/sessions/__init__.py b/examples/url_shortener/sessions/__init__.py
new file mode 100644
index 0000000..ea3f22e
--- /dev/null
+++ b/examples/url_shortener/sessions/__init__.py
@@ -0,0 +1,3 @@
+from .mongo_session import MkvSessionBackend
+
+__all__ = ["MkvSessionBackend"]
diff --git a/examples/url_shortener/sessions/mongo_session.py b/examples/url_shortener/sessions/mongo_session.py
new file mode 100644
index 0000000..1b47583
--- /dev/null
+++ b/examples/url_shortener/sessions/mongo_session.py
@@ -0,0 +1,94 @@
+import time
+from typing import Any, Dict, Optional
+
+from mongokv import Mkv
+from micropie import SessionBackend
+
+
+class MkvSessionBackend(SessionBackend):
+    """
+    Session backend backed by mongokv.Mkv.
+
+    Storage schema (per session_id):
+        key = session_id
+        value = {
+            "data": { ...session dict... },
+            "expires_at": <unix_epoch_seconds>
+        }
+
+    Notes:
+    - Expiration is enforced on load (lazy cleanup).
+    - save(..., {}, 0) deletes (matches MicroPie logout behavior).
+    """
+
+    def __init__(
+        self,
+        mongo_uri: str,
+        db_name: str,
+        collection_name: str = "sessions",
+        *,
+        key_prefix: str = "sess:",
+    ) -> None:
+        self.store = Mkv(mongo_uri, db_name=db_name, collection_name=collection_name)
+        self.key_prefix = key_prefix
+
+    def _k(self, session_id: str) -> str:
+        return f"{self.key_prefix}{session_id}"
+
+    async def load(self, session_id: str) -> Dict[str, Any]:
+        if not session_id:
+            return {}
+
+        key = self._k(session_id)
+
+        try:
+            payload = await self.store.get(key)
+        except KeyError:
+            return {}
+        except Exception:
+            # If you prefer, log this instead of swallowing.
+            return {}
+
+        if not isinstance(payload, dict):
+            # Corrupt/unexpected; treat as empty and delete
+            try:
+                await self.store.remove(key)
+            except Exception:
+                pass
+            return {}
+
+        expires_at = payload.get("expires_at")
+        if isinstance(expires_at, (int, float)) and time.time() > float(expires_at):
+            # Expired: delete and return empty
+            try:
+                await self.store.remove(key)
+            except Exception:
+                pass
+            return {}
+
+        data = payload.get("data", {})
+        return data if isinstance(data, dict) else {}
+
+    async def save(self, session_id: str, data: Dict[str, Any], timeout: int) -> None:
+        if not session_id:
+            return
+
+        key = self._k(session_id)
+
+        # MicroPie uses save(session_id, {}, 0) for logout/delete
+        if not data or timeout <= 0:
+            try:
+                await self.store.remove(key)
+            except Exception:
+                pass
+            return
+
+        expires_at = time.time() + int(timeout)
+
+        payload = {
+            "data": data,
+            "expires_at": expires_at,
+        }
+
+        await self.store.set(key, payload)
+