make sure multipart parsing stops on middleware short circuit

Commit f928599 · patx · 2025-08-20T23:14:53-04:00

Changeset
f92859974feed9c9fd828752b829954a2d44ebca
Parents
6e3459ca06ff4eac709b2347fc6539bc75dd824c

View source at this commit

Comments

No comments yet.

Log in to comment

Diff

diff --git a/docs/release_notes.md b/docs/release_notes.md
index 75db80f..bd462e8 100644
--- a/docs/release_notes.md
+++ b/docs/release_notes.md
@@ -1,6 +1,7 @@
 [![Logo](https://patx.github.io/micropie/logo.png)](https://patx.github.io/micropie)
 
 ## Releases Notes
+- **[0.22](https://github.com/patx/micropie/releases/tag/v0.23)** - Bug fix release. Make sure background multipart parsing stops when the request is terminated by middleware
 - **[0.22](https://github.com/patx/micropie/releases/tag/v0.22)** - Bug fix release. Fix sub-app body parsing bug by ensuring `Request` object inherits scope's `body_params` and `body_parsed`, preventing redundant parsing in sub-app
 - **[0.21](https://github.com/patx/micropie/releases/tag/v0.21)** - Bug fix release. Make sure index route handler can handle path params
 - **[0.20](https://github.com/patx/micropie/releases/tag/v0.20)** - Enable concurrent multipart parsing and file writing with bounded queues
diff --git a/examples/middleware/upload.py b/examples/middleware/upload.py
index f6e9e58..eb7aa50 100644
--- a/examples/middleware/upload.py
+++ b/examples/middleware/upload.py
@@ -1,9 +1,5 @@
-"""
-This file demonstrates how to use a middleware to check file upload sizes
-before the request body is processed by the multipart parser.
-"""
-
 from micropie import App, HttpMiddleware
+import asyncio
 
 MAX_UPLOAD_SIZE = 100 * 1024 * 1024  # 100MB
 
@@ -12,13 +8,26 @@ class MaxUploadSizeMiddleware(HttpMiddleware):
         # Check if we're dealing with a POST, PUT, or PATCH request
         if request.method in ("POST", "PUT", "PATCH"):
             content_length = request.headers.get("content-length")
-            # Make sure the file is not too large
-            if int(content_length) > MAX_UPLOAD_SIZE:
+            # Ensure Content-Length is present and valid
+            if content_length is None:
+                return {
+                    "status_code": 400,
+                    "body": "400 Bad Request: Missing Content-Length header"
+                }
+            try:
+                content_length = int(content_length)
+                if content_length > MAX_UPLOAD_SIZE:
+                    print(f"Upload rejected: Content-Length ({content_length}) exceeds {MAX_UPLOAD_SIZE} bytes")
+                    return {
+                        "status_code": 413,
+                        "body": "413 Payload Too Large: Uploaded file exceeds size limit."
+                    }
+            except ValueError:
                 return {
-                    "status_code": 413,
-                    "body": "413 Payload Too Large: Uploaded file exceeds size limit."
+                    "status_code": 400,
+                    "body": "400 Bad Request: Invalid Content-Length header"
                 }
-        # If the check passes, return None to continue processing.
+        # Continue processing if checks pass
         return None
 
     async def after_request(self, request, status_code, response_body, extra_headers):
@@ -28,7 +37,9 @@ class MaxUploadSizeMiddleware(HttpMiddleware):
 class FileUploadApp(App):
     async def index(self):
         """Serves an HTML form for file uploads."""
-        return """<html>
+        return (
+            200,
+            """<html>
 <head><title>File Upload</title></head>
 <body>
     <h2>Upload a File</h2>
@@ -37,11 +48,30 @@ class FileUploadApp(App):
         <input type="submit" value="Upload">
     </form>
 </body>
-</html>"""
+</html>""",
+            [("Content-Type", "text/html; charset=utf-8")]
+        )
 
     async def upload(self, file):
+        """Handles file uploads and processes the file content."""
         filename = file["filename"]
-        return filename
+        content_type = file["content_type"]
+        content_queue = file["content"]
+
+        # Process file content from the queue
+        total_size = 0
+        while True:
+            chunk = await content_queue.get()
+            if chunk is None:  # End of file
+                break
+            total_size += len(chunk)
+            # Example: Process chunk (e.g., save to disk, validate, etc.)
+            # For demonstration, just count the size
+        return {
+            "filename": filename,
+            "content_type": content_type,
+            "size": total_size
+        }
 
 
 app = FileUploadApp()
diff --git a/micropie.py b/micropie.py
index 051d39b..0cf76e8 100644
--- a/micropie.py
+++ b/micropie.py
@@ -427,7 +427,20 @@ class App:
         status_code: int = 200
         response_body: Any = ""
         extra_headers: List[Tuple[str, str]] = []
-        parse_task = None  # background multipart task if started
+        parse_task: Optional[asyncio.Task] = None  # background multipart task if started
+
+        async def _cancel_parse_task():
+            if parse_task is not None and not parse_task.done():
+                parse_task.cancel()
+                try:
+                    await parse_task
+                except asyncio.CancelledError:
+                    pass
+
+        async def _early_exit(code: int, body: Any, headers: Optional[List[Tuple[str, str]]] = None):
+            await _cancel_parse_task()
+            await self._send_response(send, code, body, headers or [])
+            return
 
         async def _await_file_param(name: str) -> Optional[Any]:
             """
@@ -457,12 +470,13 @@ class App:
                 if "multipart/form-data" in content_type:
                     if not MULTIPART_INSTALLED:
                         print("For multipart form data support install 'multipart'.")
-                        await self._send_response(send, 500, "500 Internal Server Error")
+                        await _early_exit(500, "500 Internal Server Error")
                         return
                     boundary_match = re.search(r"boundary=([^;]+)", content_type)
                     if not boundary_match:
-                        await self._send_response(send, 400, "400 Bad Request: Missing boundary")
+                        await _early_exit(400, "400 Bad Request: Missing boundary")
                         return
+                    # Start parsing in the background; do NOT await here so handlers/middleware can run concurrently.
                     parse_task = asyncio.create_task(
                         self._parse_multipart_into_request(
                             receive,
@@ -471,12 +485,6 @@ class App:
                             file_queue_maxsize=2048,
                         )
                     )
-                    try:
-                        await parse_task
-                        request.body_parsed = True
-                    except Exception as e:
-                        await self._send_response(send, 500, "500 Internal Server Error: Multipart parsing failed")
-                        return
                 else:
                     body_data = bytearray()
                     try:
@@ -488,7 +496,7 @@ class App:
                                 if not msg.get("more_body"):
                                     break
                     except asyncio.TimeoutError:
-                        await self._send_response(send, 408, "408 Request Timeout: Failed to receive body")
+                        await _early_exit(408, "408 Request Timeout: Failed to receive body")
                         return
                     decoded_body = body_data.decode("utf-8", "ignore")
                     if "application/json" in content_type:
@@ -496,8 +504,8 @@ class App:
                             request.get_json = json.loads(decoded_body)
                             if isinstance(request.get_json, dict):
                                 request.body_params = {k: [str(v)] for k, v in request.get_json.items()}
-                        except Exception as e:
-                            await self._send_response(send, 400, "400 Bad Request: Bad JSON")
+                        except Exception:
+                            await _early_exit(400, "400 Bad Request: Bad JSON")
                             return
                     else:
                         request.body_params = parse_qs(decoded_body)
@@ -511,11 +519,13 @@ class App:
                         result["body"],
                         result.get("headers", []),
                     )
-                    await self._send_response(send, status_code, response_body, extra_headers)
+                    await _early_exit(status_code, response_body, extra_headers)
                     return
 
             # Subapp handoff
             if hasattr(request, "_subapp"):
+                # If we started a multipart parse, cancel it before handing off
+                await _cancel_parse_task()
                 new_scope = dict(scope)
                 new_scope["path"] = request._subapp_path
                 new_scope["root_path"] = scope.get("root_path", "") + "/" + self.middlewares[0].mount_path
@@ -524,11 +534,12 @@ class App:
                 new_scope["get_json"] = getattr(request, "get_json", {})
                 new_scope["files"] = request.files
                 new_scope["session"] = request.session
-                # Create a receive callable that returns an empty body if already parsed
+
                 async def subapp_receive():
                     if request.body_parsed or request.body_params:
                         return {"type": "http.request", "body": b"", "more_body": False}
                     return await receive()
+
                 await request._subapp(new_scope, subapp_receive, send)
                 return
 
@@ -540,14 +551,14 @@ class App:
             else:
                 func_name: str = parts[0] if parts else "index"
                 if func_name.startswith("_") or func_name.startswith("ws_"):
-                    await self._send_response(send, 404, "404 Not Found")
+                    await _early_exit(404, "404 Not Found")
                     return
 
             if not request.path_params:
                 request.path_params = parts[1:] if len(parts) > 1 else []
             handler = getattr(self, func_name, None) or getattr(self, "index", None)
             if not handler:
-                await self._send_response(send, 404, "404 Not Found")
+                await _early_exit(404, "404 Not Found")
                 return
 
             # Initialize func_args early to avoid UnboundLocalError
@@ -561,7 +572,7 @@ class App:
                     for param in sig.parameters.values() if param.name != "self"
                 )
                 if not accepts_params:
-                    await self._send_response(send, 404, "404 Not Found")
+                    await _early_exit(404, "404 Not Found")
                     return
                 request.path_params = parts  # Pass all path parts to index handler
 
@@ -590,9 +601,7 @@ class App:
                 elif "multipart/form-data" in content_type:
                     param_value = await _await_file_param(param.name)
                     if param_value is None and param.default is param.empty:
-                        status_code = 400
-                        response_body = f"400 Bad Request: Missing required parameter '{param.name}'"
-                        await self._send_response(send, status_code, response_body)
+                        await _early_exit(400, f"400 Bad Request: Missing required parameter '{param.name}'")
                         return
                     if param_value is None:
                         param_value = param.default
@@ -601,9 +610,7 @@ class App:
                 elif param.default is not param.empty:
                     param_value = param.default
                 else:
-                    status_code = 400
-                    response_body = f"400 Bad Request: Missing required parameter '{param.name}'"
-                    await self._send_response(send, status_code, response_body)
+                    await _early_exit(400, f"400 Bad Request: Missing required parameter '{param.name}'")
                     return
 
                 func_args.append(param_value)
@@ -613,16 +620,16 @@ class App:
                 result = await handler(*func_args) if inspect.iscoroutinefunction(handler) else handler(*func_args)
             except Exception:
                 traceback.print_exc()
-                await self._send_response(send, 500, "500 Internal Server Error")
+                await _early_exit(500, "500 Internal Server Error")
                 return
 
             # Ensure background parser (if any) is finished before finalizing response
             if parse_task is not None:
                 try:
                     await parse_task
+                    request.body_parsed = True
                 except Exception:
                     traceback.print_exc()
-                    # Decide policy: you could 500 here; often handler already succeeded.
 
             # Normalize response
             if isinstance(result, tuple):
diff --git a/pyproject.toml b/pyproject.toml
index ba3fe3d..4a876c1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
 
 [project]
 name = "micropie"
-version = "0.22"
+version = "0.23"
 description = "An ultra micro ASGI web framework"
 keywords = ["micropie", "asgi", "microframework", "http"]
 readme = "docs/README.md"