Fix sub-app body parsing bug by ensuring Request object inherits scope's body_params and body_parsed, preventing redundant parsing in sub-app.

Commit a05a10e · patx · 2025-08-20T22:23:10-04:00

Changeset
a05a10e3d8ca2a39d8c1e2bb8f6d07919ccd3787
Parents
6594c0a65fc340cd868581a98279c7f4ca7173d5

View source at this commit

Comments

No comments yet.

Log in to comment

Diff

diff --git a/.gitignore b/.gitignore
index ddf19ea..54de266 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,3 +13,4 @@ dist/
 build/
 *.egg
 .pypirc
+.todo
diff --git a/examples/middleware/subapp.py b/examples/middleware/subapp.py
index b924561..f0d5d15 100644
--- a/examples/middleware/subapp.py
+++ b/examples/middleware/subapp.py
@@ -64,18 +64,15 @@ class ApiApp(App):
         return f"You have visited {visits} times."
 
     async def login(self):
-        if self.request.method == "GET":
-            csrf_token = self.request.session.get("csrf_token", "")
-            print(f"Rendering form with CSRF token: {csrf_token}")
-            return f"""<form method="POST" action="/submit">
-                <input type="hidden" name="csrf_token" value="{escape(csrf_token)}">
-                <input type="text" name="name">
-                <button type="submit">Submit</button>
-                </form>"""
-        if self.request.method == "POST":
-            name = self.request.body_params.get("name", ["World"])[0]
-            return f"Hello {name}"
-        return None
+        csrf_token = self.request.session.get("csrf_token", "")
+        return f"""<form method="POST" action="/api/plogin">
+            <input type="hidden" name="csrf_token" value="{escape(csrf_token)}">
+            <input type="text" name="name">
+            <button type="submit">Submit</button>
+            </form>"""
+        
+    async def plogin(self, name):
+        return f"Hello {name}"
 
 # Define a Middleware to Mount the Sub-App
 class SubAppMiddleware(HttpMiddleware):
diff --git a/micropie.py b/micropie.py
index e6ed820..051d39b 100644
--- a/micropie.py
+++ b/micropie.py
@@ -102,14 +102,15 @@ class Request:
         self.method: str = scope.get("method", "")
         self.path_params: List[str] = []
         self.query_params: Dict[str, List[str]] = {}
-        self.body_params: Dict[str, List[str]] = {}
-        self.get_json: Any = {}
-        self.session: Dict[str, Any] = {}
-        self.files: Dict[str, Any] = {}
+        self.body_params: Dict[str, List[str]] = scope.get("body_params", {})
+        self.get_json: Any = scope.get("get_json", {})
+        self.session: Dict[str, Any] = scope.get("session", {})
+        self.files: Dict[str, Any] = scope.get("files", {})
         self.headers: Dict[str, str] = {
             k.decode("utf-8", errors="replace").lower(): v.decode("utf-8", errors="replace")
             for k, v in scope.get("headers", [])
         }
+        self.body_parsed: bool = scope.get("body_parsed", False)
 
 class WebSocketRequest(Request):
     """Represents a WebSocket request in the MicroPie framework."""
@@ -433,18 +434,14 @@ class App:
             Wait until a multipart file field named `name` is available in request.files,
             or until the background parse_task (if any) completes.
             """
-            # Fast path
             if name in request.files:
                 return request.files[name]
-            # If no background parse, nothing to wait for
             if parse_task is None:
                 return None
-            # Wait until the field appears or parse ends
             while True:
                 if name in request.files:
                     return request.files[name]
                 if parse_task.done():
-                    # parsing finished and field never arrived
                     return None
                 await asyncio.sleep(0)
 
@@ -452,11 +449,11 @@ class App:
             # Parse query/cookies/session
             request.query_params = parse_qs(scope.get("query_string", b"").decode("utf-8", "ignore"))
             cookies = self._parse_cookies(request.headers.get("cookie", ""))
-            request.session = await self.session_backend.load(cookies.get("session_id", "")) or {}
+            request.session = scope.get("session", await self.session_backend.load(cookies.get("session_id", "")) or {})
+            content_type = request.headers.get("content-type", "")
 
             # Body parsing setup
-            content_type = request.headers.get("content-type", "")
-            if request.method in ("POST", "PUT", "PATCH"):
+            if request.method in ("POST", "PUT", "PATCH") and not request.body_parsed and not request.body_params:
                 if "multipart/form-data" in content_type:
                     if not MULTIPART_INSTALLED:
                         print("For multipart form data support install 'multipart'.")
@@ -466,7 +463,6 @@ class App:
                     if not boundary_match:
                         await self._send_response(send, 400, "400 Bad Request: Missing boundary")
                         return
-                    # Start TRUE streaming parser in background, populate request.* live
                     parse_task = asyncio.create_task(
                         self._parse_multipart_into_request(
                             receive,
@@ -475,25 +471,37 @@ class App:
                             file_queue_maxsize=2048,
                         )
                     )
+                    try:
+                        await parse_task
+                        request.body_parsed = True
+                    except Exception as e:
+                        await self._send_response(send, 500, "500 Internal Server Error: Multipart parsing failed")
+                        return
                 else:
-                    # Small-body buffering for JSON / x-www-form-urlencoded
                     body_data = bytearray()
-                    while True:
-                        msg: Dict[str, Any] = await receive()
-                        if chunk := msg.get("body", b""):
-                            body_data += chunk
-                        if not msg.get("more_body"):
-                            break
+                    try:
+                        async with asyncio.timeout(5):  # Timeout after 5 seconds
+                            while True:
+                                msg = await receive()
+                                if chunk := msg.get("body", b""):
+                                    body_data += chunk
+                                if not msg.get("more_body"):
+                                    break
+                    except asyncio.TimeoutError:
+                        await self._send_response(send, 408, "408 Request Timeout: Failed to receive body")
+                        return
+                    decoded_body = body_data.decode("utf-8", "ignore")
                     if "application/json" in content_type:
                         try:
-                            request.get_json = json.loads(body_data.decode("utf-8"))
+                            request.get_json = json.loads(decoded_body)
                             if isinstance(request.get_json, dict):
                                 request.body_params = {k: [str(v)] for k, v in request.get_json.items()}
-                        except Exception:
+                        except Exception as e:
                             await self._send_response(send, 400, "400 Bad Request: Bad JSON")
                             return
                     else:
-                        request.body_params = parse_qs(body_data.decode("utf-8", "ignore"))
+                        request.body_params = parse_qs(decoded_body)
+                    request.body_parsed = True
 
             # HTTP middlewares (before)
             for mw in self.middlewares:
@@ -511,7 +519,17 @@ class App:
                 new_scope = dict(scope)
                 new_scope["path"] = request._subapp_path
                 new_scope["root_path"] = scope.get("root_path", "") + "/" + self.middlewares[0].mount_path
-                await request._subapp(new_scope, receive, send)
+                new_scope["body_params"] = request.body_params
+                new_scope["body_parsed"] = request.body_parsed
+                new_scope["get_json"] = getattr(request, "get_json", {})
+                new_scope["files"] = request.files
+                new_scope["session"] = request.session
+                # Create a receive callable that returns an empty body if already parsed
+                async def subapp_receive():
+                    if request.body_parsed or request.body_params:
+                        return {"type": "http.request", "body": b"", "more_body": False}
+                    return await receive()
+                await request._subapp(new_scope, subapp_receive, send)
                 return
 
             # Routing
@@ -570,7 +588,6 @@ class App:
                 elif param.name in request.files:
                     param_value = request.files[param.name]
                 elif "multipart/form-data" in content_type:
-                    # If multipart and file field not yet available, wait for it to show up
                     param_value = await _await_file_param(param.name)
                     if param_value is None and param.default is param.empty:
                         status_code = 400
@@ -633,7 +650,7 @@ class App:
                         result.get("headers", extra_headers)
                     )
 
-            # --- BEGIN PATCH FOR SSE CLIENT DISCONNECT ---
+            # Handle async generators (e.g., SSE)
             if hasattr(response_body, "__aiter__"):
                 sanitized_headers: List[Tuple[str, str]] = []
                 for k, v in extra_headers:
@@ -696,13 +713,12 @@ class App:
                         except asyncio.CancelledError:
                             pass
                 return
-            # --- END PATCH FOR SSE CLIENT DISCONNECT ---
             else:
                 await self._send_response(send, status_code, response_body, extra_headers)
 
         finally:
             current_request.reset(token)
-
+            
     async def _asgi_app_websocket(
         self,
         scope: Dict[str, Any],