patx/projectpay

ensure async stripe and mongodb operations

Commit b2fdae5 · patx · 2026-06-30T09:25:03-04:00

Changeset
b2fdae50a67f15f286af6bacccc8bcfa831b09dc
Parents
f77eab7db65b6ab0ff76fc2293e53a05e9c9c880

View source at this commit

Comments

No comments yet.

Log in to comment

Diff

diff --git a/README.md b/README.md
index e90f443..6128bf0 100644
--- a/README.md
+++ b/README.md
@@ -99,11 +99,11 @@ ProjectPay is a small web app for creating project payment pages and collecting
    https://your-domain.example/webhook/stripe
    ```
 
-3. Subscribe the webhook endpoint to the `checkout.session.completed` event.
+3. Subscribe the webhook endpoint to the `checkout.session.completed` event. If you enable delayed payment methods, also subscribe to `checkout.session.async_payment_succeeded`.
 4. Copy the webhook signing secret into `STRIPE_WEBHOOK_SECRET`.
 5. Set `STRIPE_SECRET_KEY` to your Stripe secret key.
 
-ProjectPay records payments from Stripe webhooks. The customer return page may show a thank-you message immediately after Checkout redirects back, but the project balance is updated only after the signed webhook is received and processed.
+ProjectPay records payments from signed Stripe webhooks. The customer return page may show a thank-you message immediately after Checkout redirects back, but the project balance is updated only after a successful payment webhook is received and processed.
 
 For local webhook testing, use the Stripe CLI:
 
diff --git a/app.py b/app.py
index ab7b37d..05c18ed 100644
--- a/app.py
+++ b/app.py
@@ -11,7 +11,7 @@ import stripe
 from bson import ObjectId
 from bson.errors import InvalidId
 from markupsafe import Markup, escape
-from pymongo import MongoClient, ReturnDocument
+from pymongo import AsyncMongoClient, ReturnDocument
 from pymongo.errors import DuplicateKeyError
 
 from micropie import App
@@ -19,6 +19,10 @@ from micropie import App
 
 PAGE_SIZE = 20
 ADMIN_COOKIE = "invoice_admin"
+SUCCESSFUL_CHECKOUT_EVENTS = {
+    "checkout.session.completed",
+    "checkout.session.async_payment_succeeded",
+}
 
 
 def load_dotenv(path: str = ".env") -> None:
@@ -112,12 +116,13 @@ class InvoiceApp(App):
         self.mongo_uri = os.getenv("MONGODB_URI", "mongodb://localhost:27017")
         self.mongo_db_name = os.getenv("MONGODB_DB", "invoice_maker")
         timeout_ms = int(os.getenv("MONGODB_SERVER_SELECTION_TIMEOUT_MS", "5000"))
-        self.client = MongoClient(
+        self.client = AsyncMongoClient(
             self.mongo_uri,
             serverSelectionTimeoutMS=timeout_ms,
             connect=False,
         )
         self._indexes_ready = False
+        self._index_lock = asyncio.Lock()
 
         self.admin_password = os.getenv("ADMIN_PASSWORD", "admin")
         self.admin_cookie_secret = (
@@ -136,36 +141,56 @@ class InvoiceApp(App):
         self.stripe_minimum_amount_cents = int(
             os.getenv("STRIPE_MINIMUM_AMOUNT_CENTS", "100") or "100"
         )
+        self.stripe_http_client = None
         if self.stripe_secret_key:
             stripe.api_key = self.stripe_secret_key
+            self.stripe_http_client = stripe.HTTPXClient()
+            stripe.default_http_client = self.stripe_http_client
+
+        self.shutdown_handlers.append(self._shutdown_clients)
 
         if self.env is not None:
             self.env.filters["money"] = cents_to_money
             self.env.filters["star_emphasis"] = format_star_emphasis
             self.env.globals["money"] = cents_to_money
 
-    @property
-    def db(self):
-        self._ensure_indexes()
+    async def _db(self):
+        await self._ensure_indexes()
         return self.client[self.mongo_db_name]
 
-    def _ensure_indexes(self) -> None:
+    async def _ensure_indexes(self) -> None:
         if self._indexes_ready:
             return
-        db = self.client[self.mongo_db_name]
-        db.projects.create_index("project_number", unique=True)
-        db.projects.create_index("share_token", unique=True)
-        db.projects.create_index("customer_name")
-        db.projects.create_index("created_at")
-        db.payments.create_index("project_id")
-        db.payments.create_index("webhook_id", unique=True, sparse=True)
-        db.payments.create_index("stripe_event_id", unique=True, sparse=True)
-        db.payments.create_index("stripe_payment_intent_id", unique=True, sparse=True)
-        db.payments.create_index("stripe_checkout_session_id", unique=True, sparse=True)
-        db.checkout_sessions.create_index("session_id", unique=True, sparse=True)
-        db.checkout_sessions.create_index("project_id")
-        db.webhook_events.create_index("webhook_id", unique=True, sparse=True)
-        self._indexes_ready = True
+        async with self._index_lock:
+            if self._indexes_ready:
+                return
+            db = self.client[self.mongo_db_name]
+            await db.projects.create_index("project_number", unique=True)
+            await db.projects.create_index("share_token", unique=True)
+            await db.projects.create_index("customer_name")
+            await db.projects.create_index("created_at")
+            await db.payments.create_index("project_id")
+            await db.payments.create_index("webhook_id", unique=True, sparse=True)
+            await db.payments.create_index("stripe_event_id", unique=True, sparse=True)
+            await db.payments.create_index(
+                "stripe_payment_intent_id", unique=True, sparse=True
+            )
+            await db.payments.create_index(
+                "stripe_checkout_session_id", unique=True, sparse=True
+            )
+            await db.checkout_sessions.create_index(
+                "session_id", unique=True, sparse=True
+            )
+            await db.checkout_sessions.create_index("project_id")
+            await db.webhook_events.create_index(
+                "webhook_id", unique=True, sparse=True
+            )
+            self._indexes_ready = True
+
+    async def _shutdown_clients(self) -> None:
+        if self.stripe_http_client is not None:
+            await self.stripe_http_client.close_async()
+        await self.client.close()
 
     async def __call__(self, scope, receive, send) -> None:
         if (
@@ -195,9 +220,7 @@ class InvoiceApp(App):
             key.decode("utf-8", "replace").lower(): value.decode("utf-8", "replace")
             for key, value in scope.get("headers", [])
         }
-        status, body = await asyncio.to_thread(
-            self._handle_stripe_webhook, raw_body, headers
-        )
+        status, body = await self._handle_stripe_webhook(raw_body, headers)
         await self._send_response(
             send,
             status,
@@ -205,7 +228,7 @@ class InvoiceApp(App):
             [("Content-Type", "application/json")],
         )
 
-    def _handle_stripe_webhook(
+    async def _handle_stripe_webhook(
         self, raw_body: bytes, headers: Dict[str, str]
     ) -> Tuple[int, Dict[str, Any]]:
         if not self.stripe_webhook_secret:
@@ -228,43 +251,58 @@ class InvoiceApp(App):
             return 400, {"error": "Webhook event id is missing"}
         event_type = payload.get("type", "")
         created_at = now_utc()
+        db = await self._db()
+
+        existing_event = await db.webhook_events.find_one(
+            {"webhook_id": webhook_id},
+            {"processed": 1},
+        )
+        if existing_event and existing_event.get("processed") is True:
+            return 200, {"received": True, "duplicate": True}
 
         try:
-            result = self.db.webhook_events.update_one(
+            await db.webhook_events.update_one(
                 {"webhook_id": webhook_id},
                 {
                     "$setOnInsert": {
                         "webhook_id": webhook_id,
+                        "created_at": created_at,
+                    },
+                    "$set": {
                         "provider": "stripe",
                         "event_type": event_type,
                         "payload": payload,
                         "processed": False,
-                        "created_at": created_at,
-                    }
+                        "last_received_at": created_at,
+                    },
+                    "$inc": {"attempts": 1},
                 },
                 upsert=True,
             )
         except DuplicateKeyError:
-            return 200, {"received": True, "duplicate": True}
-
-        if result.upserted_id is None:
-            return 200, {"received": True, "duplicate": True}
+            existing_event = await db.webhook_events.find_one(
+                {"webhook_id": webhook_id},
+                {"processed": 1},
+            )
+            if existing_event and existing_event.get("processed") is True:
+                return 200, {"received": True, "duplicate": True}
 
         try:
-            outcome = self._process_stripe_webhook(payload, webhook_id)
-            self.db.webhook_events.update_one(
+            outcome = await self._process_stripe_webhook(payload, webhook_id)
+            await db.webhook_events.update_one(
                 {"webhook_id": webhook_id},
                 {
                     "$set": {
                         "processed": True,
                         "processed_at": now_utc(),
                         "outcome": outcome,
-                    }
+                    },
+                    "$unset": {"error": ""},
                 },
             )
             return 200, {"received": True, **outcome}
         except Exception as exc:
-            self.db.webhook_events.update_one(
+            await db.webhook_events.update_one(
                 {"webhook_id": webhook_id},
                 {
                     "$set": {
@@ -276,20 +314,20 @@ class InvoiceApp(App):
             )
             return 500, {"error": "Webhook processing failed"}
 
-    def _process_stripe_webhook(
+    async def _process_stripe_webhook(
         self, payload: Dict[str, Any], webhook_id: str
     ) -> Dict[str, Any]:
         event_type = payload.get("type", "")
         data = payload.get("data") or {}
         session = as_plain_dict(data.get("object"))
 
-        if event_type != "checkout.session.completed":
-            return {"ignored": True, "reason": "not a completed checkout session"}
+        if event_type not in SUCCESSFUL_CHECKOUT_EVENTS:
+            return {"ignored": True, "reason": "not a successful checkout session"}
 
         if str(session.get("payment_status") or "").lower() != "paid":
             return {"ignored": True, "reason": "checkout session is not paid"}
 
-        project = self._project_from_stripe_session(session)
+        project = await self._project_from_stripe_session(session)
         if project is None:
             return {"ignored": True, "reason": "project metadata not found"}
 
@@ -327,20 +365,18 @@ class InvoiceApp(App):
         elif checkout_session_id:
             query = {"stripe_checkout_session_id": checkout_session_id}
 
+        db = await self._db()
         try:
-            insert_result = self.db.payments.update_one(
+            insert_result = await db.payments.update_one(
                 query,
                 {"$setOnInsert": payment_doc},
                 upsert=True,
             )
         except DuplicateKeyError:
-            return {"received": True, "duplicate": True}
-
-        if insert_result.upserted_id is None:
-            return {"received": True, "duplicate": True}
+            insert_result = None
 
         if checkout_session_id:
-            self.db.checkout_sessions.update_one(
+            await db.checkout_sessions.update_one(
                 {"session_id": checkout_session_id},
                 {
                     "$set": {
@@ -352,10 +388,12 @@ class InvoiceApp(App):
                 },
             )
 
-        self._refresh_project_status(project["_id"])
+        await self._refresh_project_status(project["_id"])
+        if insert_result is None or insert_result.upserted_id is None:
+            return {"received": True, "duplicate": True}
         return {"received": True, "payment_recorded": True}
 
-    def _project_from_stripe_session(
+    async def _project_from_stripe_session(
         self, session: Dict[str, Any]
     ) -> Optional[Dict[str, Any]]:
         metadata = as_plain_dict(session.get("metadata"))
@@ -365,26 +403,27 @@ class InvoiceApp(App):
             or session.get("client_reference_id")
         )
         share_token = metadata.get("share_token") or metadata.get("shareToken")
+        db = await self._db()
 
         if project_id:
             oid = object_id(str(project_id))
             if oid is not None:
-                project = self.db.projects.find_one({"_id": oid})
+                project = await db.projects.find_one({"_id": oid})
                 if project:
                     return project
 
         if share_token:
-            project = self.db.projects.find_one({"share_token": str(share_token)})
+            project = await db.projects.find_one({"share_token": str(share_token)})
             if project:
                 return project
 
         checkout_session_id = session.get("id")
         if checkout_session_id:
-            session = self.db.checkout_sessions.find_one(
+            checkout_doc = await db.checkout_sessions.find_one(
                 {"session_id": str(checkout_session_id)}
             )
-            if session:
-                return self.db.projects.find_one({"_id": session["project_id"]})
+            if checkout_doc:
+                return await db.projects.find_one({"_id": checkout_doc["project_id"]})
         return None
 
     def _stripe_metadata(self, project: Dict[str, Any]) -> Dict[str, str]:
@@ -394,14 +433,14 @@ class InvoiceApp(App):
             "share_token": str(project["share_token"]),
         }
 
-    def _create_checkout_session(self, project: Dict[str, Any]) -> Dict[str, Any]:
+    async def _create_checkout_session(self, project: Dict[str, Any]) -> Dict[str, Any]:
         if not self.stripe_secret_key:
             raise RuntimeError("STRIPE_SECRET_KEY is not configured")
         if not self.stripe_product_id:
             raise RuntimeError("STRIPE_PRODUCT_ID is not configured")
 
         stripe.api_key = self.stripe_secret_key
-        summary = self._summarize_project(project)
+        summary = await self._summarize_project(project)
         balance_cents = int(summary["balance_cents"])
         if balance_cents <= 0:
             raise RuntimeError("Project is already paid in full")
@@ -410,7 +449,7 @@ class InvoiceApp(App):
         metadata = self._stripe_metadata(project)
         minimum_amount_cents = max(1, min(self.stripe_minimum_amount_cents, balance_cents))
 
-        price = stripe.Price.create(
+        price = await stripe.Price.create_async(
             currency=self.currency.lower(),
             product=self.stripe_product_id,
             custom_unit_amount={
@@ -438,7 +477,7 @@ class InvoiceApp(App):
         if project.get("customer_email"):
             payload["customer_email"] = project["customer_email"]
 
-        response = stripe.checkout.Session.create(**payload)
+        response = await stripe.checkout.Session.create_async(**payload)
         response_dict = as_plain_dict(response)
         checkout_url = response_dict.get("url") or getattr(response, "url", None)
         session_id = response_dict.get("id") or getattr(response, "id", None)
@@ -492,8 +531,9 @@ class InvoiceApp(App):
     def _public_project_url(self, project: Dict[str, Any]) -> str:
         return f"{self._base_url()}/p/{project['share_token']}"
 
-    def _next_project_number(self) -> str:
-        counter = self.db.counters.find_one_and_update(
+    async def _next_project_number(self) -> str:
+        db = await self._db()
+        counter = await db.counters.find_one_and_update(
             {"_id": "project_number"},
             {"$inc": {"value": 1}},
             upsert=True,
@@ -504,17 +544,19 @@ class InvoiceApp(App):
     def _project_total(self, project: Dict[str, Any]) -> int:
         return sum(int(item.get("amount_cents", 0)) for item in project.get("line_items", []))
 
-    def _project_paid(self, project_id: ObjectId) -> int:
+    async def _project_paid(self, project_id: ObjectId) -> int:
         pipeline = [
             {"$match": {"project_id": project_id, "status": "succeeded"}},
             {"$group": {"_id": "$project_id", "total": {"$sum": "$amount_cents"}}},
         ]
-        rows = list(self.db.payments.aggregate(pipeline))
+        db = await self._db()
+        cursor = await db.payments.aggregate(pipeline)
+        rows = await cursor.to_list(length=1)
         return int(rows[0]["total"]) if rows else 0
 
-    def _summarize_project(self, project: Dict[str, Any]) -> Dict[str, Any]:
+    async def _summarize_project(self, project: Dict[str, Any]) -> Dict[str, Any]:
         total = self._project_total(project)
-        paid = self._project_paid(project["_id"])
+        paid = await self._project_paid(project["_id"])
         balance = max(total - paid, 0)
         status = "paid" if total > 0 and balance == 0 else project.get("status", "open")
         return {
@@ -528,7 +570,7 @@ class InvoiceApp(App):
             "share_url": self._public_project_url(project),
         }
 
-    def _find_projects(self, q: str, page: int) -> Tuple[List[Dict[str, Any]], bool]:
+    async def _find_projects(self, q: str, page: int) -> Tuple[List[Dict[str, Any]], bool]:
         query: Dict[str, Any] = {}
         q = (q or "").strip()
         if q:
@@ -543,14 +585,12 @@ class InvoiceApp(App):
             query = {"$or": clauses}
 
         skip = max(page - 1, 0) * PAGE_SIZE
-        docs = list(
-            self.db.projects.find(query)
-            .sort("created_at", -1)
-            .skip(skip)
-            .limit(PAGE_SIZE + 1)
-        )
+        db = await self._db()
+        cursor = db.projects.find(query).sort("created_at", -1).skip(skip).limit(PAGE_SIZE + 1)
+        docs = await cursor.to_list(length=PAGE_SIZE + 1)
         has_more = len(docs) > PAGE_SIZE
-        return [self._summarize_project(doc) for doc in docs[:PAGE_SIZE]], has_more
+        projects = [await self._summarize_project(doc) for doc in docs[:PAGE_SIZE]]
+        return projects, has_more
 
     def _parse_line_items(self) -> Tuple[List[Dict[str, Any]], List[str]]:
         names = self.request.body_params.get("item_name", [])
@@ -618,14 +658,15 @@ class InvoiceApp(App):
         project["errors"] = sorted(set(errors))
         return project
 
-    def _refresh_project_status(self, project_id: ObjectId) -> None:
-        project = self.db.projects.find_one({"_id": project_id})
+    async def _refresh_project_status(self, project_id: ObjectId) -> None:
+        db = await self._db()
+        project = await db.projects.find_one({"_id": project_id})
         if not project:
             return
         total = self._project_total(project)
-        paid = self._project_paid(project_id)
+        paid = await self._project_paid(project_id)
         status = "paid" if total > 0 and paid >= total else "open"
-        self.db.projects.update_one(
+        await db.projects.update_one(
             {"_id": project_id},
             {"$set": {"status": status, "updated_at": now_utc()}},
         )
@@ -642,7 +683,7 @@ class InvoiceApp(App):
             return redirect
         page = int(self.request.query("page", "1") or "1")
         q = self.request.query("q", "") or ""
-        projects, has_more = self._find_projects(q, page)
+        projects, has_more = await self._find_projects(q, page)
         return await self._render(
             "index.html",
             projects=projects,
@@ -656,7 +697,7 @@ class InvoiceApp(App):
             return redirect
         page = int(self.request.query("page", "1") or "1")
         q = self.request.query("q", "") or ""
-        projects, has_more = self._find_projects(q, page)
+        projects, has_more = await self._find_projects(q, page)
         return await self._render(
             "_project_rows.html",
             projects=projects,
@@ -704,8 +745,9 @@ class InvoiceApp(App):
             return await self._render("form.html", project=project, mode="new")
 
         now = now_utc()
+        project_number = await self._next_project_number()
         project_doc = {
-            "project_number": self._next_project_number(),
+            "project_number": project_number,
             "share_token": secrets.token_urlsafe(24),
             "customer_name": project["customer_name"],
             "customer_email": project["customer_email"],
@@ -715,14 +757,16 @@ class InvoiceApp(App):
             "created_at": now,
             "updated_at": now,
         }
-        result = self.db.projects.insert_one(project_doc)
+        db = await self._db()
+        result = await db.projects.insert_one(project_doc)
         return self._redirect(f"/project/{result.inserted_id}")
 
     async def edit(self, project_id: str) -> Any:
         if redirect := self._require_admin():
             return redirect
         oid = object_id(project_id)
-        project = self.db.projects.find_one({"_id": oid}) if oid else None
+        db = await self._db()
+        project = await db.projects.find_one({"_id": oid}) if oid else None
         if not project:
             return 404, "Project not found"
 
@@ -730,7 +774,7 @@ class InvoiceApp(App):
             updated = self._project_from_form(project)
             if updated["errors"]:
                 return await self._render("form.html", project=updated, mode="edit")
-            self.db.projects.update_one(
+            await db.projects.update_one(
                 {"_id": project["_id"]},
                 {
                     "$set": {
@@ -742,10 +786,10 @@ class InvoiceApp(App):
                     }
                 },
             )
-            self._refresh_project_status(project["_id"])
+            await self._refresh_project_status(project["_id"])
             return self._redirect(f"/project/{project['_id']}")
 
-        project = self._summarize_project(project)
+        project = await self._summarize_project(project)
         project["errors"] = []
         return await self._render("form.html", project=project, mode="edit")
 
@@ -753,13 +797,14 @@ class InvoiceApp(App):
         if redirect := self._require_admin():
             return redirect
         oid = object_id(project_id)
-        project = self.db.projects.find_one({"_id": oid}) if oid else None
+        db = await self._db()
+        project = await db.projects.find_one({"_id": oid}) if oid else None
         if not project:
             return 404, "Project not found"
-        summary = self._summarize_project(project)
-        payments = list(
-            self.db.payments.find({"project_id": project["_id"]}).sort("created_at", -1)
-        )
+        summary = await self._summarize_project(project)
+        payments = await db.payments.find({"project_id": project["_id"]}).sort(
+            "created_at", -1
+        ).to_list(length=None)
         for payment in payments:
             payment["id"] = str(payment["_id"])
         return await self._render(
@@ -772,32 +817,35 @@ class InvoiceApp(App):
         if redirect := self._require_admin():
             return redirect
         oid = object_id(project_id)
-        project = self.db.projects.find_one({"_id": oid}) if oid else None
+        db = await self._db()
+        project = await db.projects.find_one({"_id": oid}) if oid else None
         if not project:
             return 404, "Project not found"
         if self.request.method != "POST":
             return self._redirect(f"/project/{project_id}")
 
+        payment_docs = await db.payments.find(
+            {"project_id": project["_id"], "webhook_id": {"$exists": True}},
+            {"webhook_id": 1},
+        ).to_list(length=None)
         payment_webhook_ids = [
-            payment["webhook_id"]
-            for payment in self.db.payments.find(
-                {"project_id": project["_id"], "webhook_id": {"$exists": True}},
-                {"webhook_id": 1},
-            )
-            if payment.get("webhook_id")
+            payment["webhook_id"] for payment in payment_docs if payment.get("webhook_id")
         ]
         if payment_webhook_ids:
-            self.db.webhook_events.delete_many({"webhook_id": {"$in": payment_webhook_ids}})
-        self.db.checkout_sessions.delete_many({"project_id": project["_id"]})
-        self.db.payments.delete_many({"project_id": project["_id"]})
-        self.db.projects.delete_one({"_id": project["_id"]})
+            await db.webhook_events.delete_many(
+                {"webhook_id": {"$in": payment_webhook_ids}}
+            )
+        await db.checkout_sessions.delete_many({"project_id": project["_id"]})
+        await db.payments.delete_many({"project_id": project["_id"]})
+        await db.projects.delete_one({"_id": project["_id"]})
         return self._redirect("/")
 
     async def p(self, share_token: str) -> Any:
-        project = self.db.projects.find_one({"share_token": share_token})
+        db = await self._db()
+        project = await db.projects.find_one({"share_token": share_token})
         if not project:
             return 404, "Project not found"
-        summary = self._summarize_project(project)
+        summary = await self._summarize_project(project)
         checkout_state = self.request.query("checkout", "") or ""
         return await self._render(
             "public_project.html",
@@ -806,10 +854,11 @@ class InvoiceApp(App):
         )
 
     async def pay(self, share_token: str) -> Any:
-        project = self.db.projects.find_one({"share_token": share_token})
+        db = await self._db()
+        project = await db.projects.find_one({"share_token": share_token})
         if not project:
             return 404, "Project not found"
-        summary = self._summarize_project(project)
+        summary = await self._summarize_project(project)
 
         if self.request.method != "POST":
             return self._redirect(f"/p/{share_token}")
@@ -818,7 +867,7 @@ class InvoiceApp(App):
             return self._redirect(f"/p/{share_token}")
 
         try:
-            checkout = await asyncio.to_thread(self._create_checkout_session, project)
+            checkout = await self._create_checkout_session(project)
         except Exception as exc:
             return await self._render(
                 "public_project.html",
@@ -840,7 +889,7 @@ class InvoiceApp(App):
             checkout_doc["stripe_price_id"] = checkout["stripe_price_id"]
         if checkout.get("balance_cents"):
             checkout_doc["balance_cents_at_creation"] = checkout["balance_cents"]
-        self.db.checkout_sessions.insert_one(checkout_doc)
+        await db.checkout_sessions.insert_one(checkout_doc)
         return self._redirect(checkout["checkout_url"])
 
 
diff --git a/requirements.txt b/requirements.txt
index e3c5a0a..6166909 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 micropie[all]
-pymongo
-stripe
+pymongo>=4.17.0
+stripe>=15.1.0
+httpx