Add websocket middleware, improve explicit routing examples

Commit 760bd4c · patx · 2025-06-24T01:23:11-04:00

Changeset
760bd4c05353c2897e65f15f289e4f5c9af43125
Parents
5142add6469449e989b72eb760c41dcade3ff363

View source at this commit

Comments

No comments yet.

Log in to comment

Diff

diff --git a/MicroPie.py b/MicroPie.py
index b26116b..55a26e6 100644
--- a/MicroPie.py
+++ b/MicroPie.py
@@ -256,13 +256,20 @@ class ConnectionClosed(Exception):
 # -----------------------------
 class HttpMiddleware(ABC):
     """
-    Pluggable middleware class that allows hooking into the request lifecycle.
+    Pluggable middleware class that allows hooking into the HTTP request lifecycle.
     """
 
     @abstractmethod
-    async def before_request(self, request: Request) -> None:
+    async def before_request(self, request: Request) -> Optional[Dict]:
         """
-        Called before the request is processed.
+        Called before the HTTP request is processed.
+        
+        Args:
+            request: The Request object.
+        
+        Returns:
+            Optional dictionary with response details (status_code, body, headers) to short-circuit the request,
+            or None to continue processing.
         """
         pass
 
@@ -273,11 +280,47 @@ class HttpMiddleware(ABC):
         status_code: int,
         response_body: Any,
         extra_headers: List[Tuple[str, str]]
-    ) -> None:
+    ) -> Optional[Dict]:
+        """
+        Called after the HTTP request is processed, but before the final response is sent.
+        
+        Args:
+            request: The Request object.
+            status_code: The HTTP status code.
+            response_body: The response body.
+            extra_headers: List of header tuples.
+        
+        Returns:
+            Optional dictionary with updated response details (status_code, body, headers), or None to use defaults.
+        """
+        pass
+
+class WebSocketMiddleware(ABC):
+    """
+    Pluggable middleware class that allows hooking into the WebSocket request lifecycle.
+    """
+
+    @abstractmethod
+    async def before_websocket(self, request: WebSocketRequest) -> Optional[Dict]:
         """
-        Called after the request is processed, but before the final response
-        is sent to the client. You may alter the status_code, response_body,
-        or extra_headers if needed.
+        Called before the WebSocket handler is invoked.
+        
+        Args:
+            request: The WebSocketRequest object.
+        
+        Returns:
+            Optional dictionary with close details (code, reason) to reject the connection,
+            or None to continue processing.
+        """
+        pass
+
+    @abstractmethod
+    async def after_websocket(self, request: WebSocketRequest) -> None:
+        """
+        Called after the WebSocket handler completes.
+        
+        Args:
+            request: The WebSocketRequest object.
         """
         pass
 
@@ -288,8 +331,8 @@ class HttpMiddleware(ABC):
 class App:
     """
     ASGI application for handling HTTP and WebSocket requests in MicroPie.
-    It supports pluggable session backends via the 'session_backend' attribute
-    and pluggable middlewares via the 'middlewares' list.
+    It supports pluggable session backends via the 'session_backend' attribute,
+    pluggable HTTP middlewares via the 'middlewares' list, and WebSocket middlewares via the 'ws_middlewares' list.
     """
 
     def __init__(self, session_backend: Optional[SessionBackend] = None) -> None:
@@ -303,6 +346,7 @@ class App:
             self.env = None
         self.session_backend: SessionBackend = session_backend or InMemorySessionBackend()
         self.middlewares: List[HttpMiddleware] = []
+        self.ws_middlewares: List[WebSocketMiddleware] = []
 
     @property
     def request(self) -> Request:
@@ -517,6 +561,13 @@ class App:
             cookies = self._parse_cookies(request.headers.get("cookie", ""))
             request.session = await self.session_backend.load(cookies.get("session_id", "")) or {}
 
+            # Run WebSocket middleware before_websocket
+            for mw in self.ws_middlewares:
+                if result := await mw.before_websocket(request):
+                    code, reason = result.get("code", 1008), result.get("reason", "Middleware rejected")
+                    await self._send_websocket_close(send, code, reason)
+                    return
+
             # Parse path and find handler
             path: str = scope["path"].lstrip("/")
             parts: List[str] = path.split("/") if path else []
@@ -527,6 +578,8 @@ class App:
 
             # Map WebSocket handler (e.g., /chat -> ws_chat)
             handler_name = f"ws_{func_name}" if func_name else "ws_index"
+            if hasattr(request, "_ws_route_handler"):
+                handler_name = request._ws_route_handler
             request.path_params = parts[1:] if len(parts) > 1 else []
             handler = getattr(self, handler_name, None)
             if not handler:
@@ -578,6 +631,10 @@ class App:
             if request.session:
                 await self.session_backend.save(ws.session_id, request.session, SESSION_TIMEOUT)
 
+            # Run WebSocket middleware after_websocket
+            for mw in self.ws_middlewares:
+                await mw.after_websocket(request)
+
         finally:
             current_request.reset(token)
 
diff --git a/README.md b/README.md
index 6207143..7a8b164 100644
--- a/README.md
+++ b/README.md
@@ -106,7 +106,7 @@ MicroPie's route handlers map URLs to methods in your `App` subclass, handling H
   - Sync/async generator for streaming.
 
 #### **Advanced Usage**
-- **Custom Routing**: Use middleware for explicit routing (see [examples/middleware](https://github.com/patx/micropie/tree/main/examples/middleware) and [examples/rest](https://github.com/patx/micropie/tree/main/examples/rest)).
+- **Custom Routing**: Use middleware for explicit routing (see [examples/middleware](https://github.com/patx/micropie/tree/main/examples/middleware) and [examples/explicit_routing](https://github.com/patx/micropie/tree/main/examples/explicit_routing)).
 - **Errors**: Auto-handled 404/400; customize via middleware.
 - **Dynamic Params**: Use `*args` for multiple path parameters.
 
@@ -139,7 +139,7 @@ class MyApp(App):
         return f"Submitted by: {username}"
 ```
 
-By default, MicroPie's route handlers can accept any request method, it's up to you how to handle any incoming requests! You can check the request method (and an number of other things specific to the current request state) in the handler with`self.request.method`. You can see how to handle POST JSON data at [examples/api](https://github.com/patx/micropie/tree/main/examples/api).
+By default, MicroPie's route handlers can accept any request method, it's up to you how to handle any incoming requests! You can check the request method (and an number of other things specific to the current request state) in the handler with`self.request.method`. You can see how to handle POST JSON data at [examples/api](https://github.com/patx/micropie/tree/main/examples/api) and [examples/json](https://github.com/patx/micropie/tree/main/examples/json).
 
 ### Real-Time Communication with WebSockets and Socket.IO
 MicroPie includes built-in support for WebSocket connections. WebSocket routes are defined in your App subclass using methods prefixed with `ws_`, mirroring the simplicity of MicroPie's HTTP routing. For example, a method named `ws_chat` handles WebSocket connections at `ws://<host>/chat`.
@@ -318,17 +318,27 @@ An in-memory implementation of the `SessionBackend`.
 
 ## Middleware Abstraction
 
-MicroPie allows you to create pluggable middleware to hook into the request lifecycle.
+MicroPie allows you to create pluggable middleware to hook into the request lifecycle for both HTTP and WebSocket requests.
 
 ### `HttpMiddleware` Class
 
 #### Methods
 
-- `before_request(request: Request) -> None`
-  - Abstract method called before the request is processed.
+- `before_request(request: Request) -> Optional[Dict]`
+  - Abstract method called before the HTTP request is processed. Returns an optional dictionary with response details (status_code, body, headers) to short-circuit the request, or None to continue processing.
 
-- `after_request(request: Request, status_code: int, response_body: Any, extra_headers: List[Tuple[str, str]]) -> None`
-  - Abstract method called after the request is processed but before the final response is sent to the client.
+- `after_request(request: Request, status_code: int, response_body: Any, extra_headers: List[Tuple[str, str]]) -> Optional[Dict]`
+  - Abstract method called after the HTTP request is processed but before the final response is sent. Returns an optional dictionary with updated response details (status_code, body, headers), or None to use defaults.
+
+### `WebSocketMiddleware` Class
+
+#### Methods
+
+- `before_websocket(request: WebSocketRequest) -> Optional[Dict]`
+  - Abstract method called before the WebSocket handler is invoked. Returns an optional dictionary with close details (code, reason) to reject the connection, or None to continue processing.
+
+- `after_websocket(request: WebSocketRequest) -> None`
+  - Abstract method called after the WebSocket handler completes.
 
 ## Request Objects
 
@@ -400,10 +410,15 @@ An exception raised when a WebSocket connection is closed.
 
 The main ASGI application class for handling HTTP and WebSocket requests in MicroPie.
 
+#### Attributes
+
+- `middlewares`: List of `HttpMiddleware` instances for HTTP request processing.
+- `ws_middlewares`: List of `WebSocketMiddleware` instances for WebSocket request processing.
+
 #### Methods
 
 - `__init__(session_backend: Optional[SessionBackend] = None) -> None`
-  - Initializes the application with an optional session backend.
+  - Initializes the application with an optional session backend and empty middleware lists for HTTP and WebSocket requests.
 
 - `request -> Request`
   - Retrieves the current request from the context variable.
diff --git a/examples/api/simple.py b/examples/api/simple.py
index ea2b9f9..8332523 100644
--- a/examples/api/simple.py
+++ b/examples/api/simple.py
@@ -4,10 +4,13 @@ from MicroPie import App
 class Root(App):
 
     async def index(self, id, name, age):
-        if self.request.method == 'POST':
-            return {'id': id,'name': name,'age': age}
+        if self.request.method == "POST":
+            return {"id": id,"name": name,"age": age}
 
     async def echo(self):
+        if self.request.method == "GET":
+            return {"input": False, "extra": False}
+
         data = self.request.get_json
         return {"input": data, "extra": True}
 
@@ -15,6 +18,6 @@ class Root(App):
         return ["a", "b"]
 
     async def html(self):
-        return 'Hello world'
+        return "<b>Hello world</b>"
 
 app = Root()
diff --git a/examples/rest/app.py b/examples/explicit_routing/http.py
similarity index 94%
rename from examples/rest/app.py
rename to examples/explicit_routing/http.py
index 3030b72..25a39a0 100644
--- a/examples/rest/app.py
+++ b/examples/explicit_routing/http.py
@@ -1,6 +1,6 @@
-from micropie_rest import RESTApp, route
+from micropie_routing import ExplicitApp, route
 
-class MyApp(RESTApp):
+class MyApp(ExplicitApp):
 
     @route("/api/users/{user:str}/records/{record:int}", method=["GET", "HEAD"])
     async def _get_record(self, user: str, record: int):
diff --git a/examples/rest/micropie_rest.py b/examples/explicit_routing/micropie_routing.py
similarity index 51%
rename from examples/rest/micropie_rest.py
rename to examples/explicit_routing/micropie_routing.py
index 8050a18..33d9c62 100644
--- a/examples/rest/micropie_rest.py
+++ b/examples/explicit_routing/micropie_routing.py
@@ -1,6 +1,6 @@
 import re
 from typing import Dict, List, Optional, Tuple, Any, Callable, Type, Union
-from MicroPie import App, HttpMiddleware, Request
+from MicroPie import App, HttpMiddleware, WebSocketMiddleware, Request, WebSocketRequest
 
 class RouteError(Exception):
     """Custom exception for route-related errors."""
@@ -51,7 +51,7 @@ class ExplicitRouter(HttpMiddleware):
             request: The MicroPie Request object
         
         Returns:
-            None to let MicroPie handle parsing and routing
+            Dictionary with response details to short-circuit, or None to continue.
         """
         path = request.scope["path"]
         
@@ -69,7 +69,7 @@ class ExplicitRouter(HttpMiddleware):
                     request._route_handler = handler_name  # Set handler name as string
                     return None
                 except ValueError as e:
-                    return {"error": f"Invalid parameter format: {str(e)}"}
+                    return {"status_code": 400, "body": f"Invalid parameter format: {str(e)}"}
         
         return None
     
@@ -82,8 +82,72 @@ class ExplicitRouter(HttpMiddleware):
     ) -> Optional[Dict]:
         return None
 
+class WebSocketExplicitRouter(WebSocketMiddleware):
+    def __init__(self):
+        # Map WebSocket route paths to (regex pattern, handler_name, param_types)
+        self.routes: Dict[str, Tuple[str, str, List[Type]]] = {}
+    
+    def add_route(self, path: str, handler: Callable) -> None:
+        """
+        Register an explicit WebSocket route with its handler.
+        
+        Args:
+            path: The route pattern (e.g., "/ws/users/{user:str}/chat")
+            handler: The handler function
+        """
+        # Parse parameter types from path
+        param_types = []
+        pattern = re.sub(r"{([^:]+):([^}]+)}", lambda m: self._process_param(m, param_types), path)
+        pattern = f"^{pattern}$"
+        # Store handler name instead of handler function
+        self.routes[path] = (pattern, handler.__name__, param_types)
+    
+    def _process_param(self, match: re.Match, param_types: List[Type]) -> str:
+        """Process a route parameter and store its type."""
+        param_name, param_type = match.group(1), match.group(2)
+        if param_type == "int":
+            param_types.append(int)
+            return r"(\d+)"
+        elif param_type == "str":
+            param_types.append(str)
+            return r"([^/]+)"
+        else:
+            raise RouteError(f"Unsupported parameter type: {param_type}")
+    
+    async def before_websocket(self, request: WebSocketRequest) -> Optional[Dict]:
+        """
+        Match the WebSocket path and set path parameters for routing.
+        
+        Args:
+            request: The WebSocketRequest object
+        
+        Returns:
+            Dictionary with close details to reject, or None to continue.
+        """
+        path = request.scope["path"]
+        
+        for route_path, (pattern, handler_name, param_types) in self.routes.items():
+            match = re.match(pattern, path)
+            if match:
+                try:
+                    # Convert parameters to their specified types
+                    params = [
+                        param_type(param) for param, param_type in zip(match.groups(), param_types)
+                    ]
+                    request.path_params = params
+                    request._ws_route_handler = handler_name  # Set handler name as string
+                    return None
+                except ValueError as e:
+                    return {"code": 1008, "reason": f"Invalid parameter format: {str(e)}"}
+        
+        return None
+    
+    async def after_websocket(self, request: WebSocketRequest) -> None:
+        """Post-processing after WebSocket handler execution."""
+        pass
+
 def route(path: str, method: Union[str, List[str]] = "GET"):
-    """Decorator to register a route for a handler method."""
+    """Decorator to register a route for an HTTP handler method."""
     def decorator(handler: Callable) -> Callable:
         # Normalize method to a list
         methods = [method] if isinstance(method, str) else method
@@ -91,17 +155,29 @@ def route(path: str, method: Union[str, List[str]] = "GET"):
         return handler
     return decorator
 
-class RESTApp(App):
-    """A subclass of MicroPie.App that automatically registers routes using ExplicitRouter."""
+def ws_route(path: str):
+    """Decorator to register a route for a WebSocket handler method."""
+    def decorator(handler: Callable) -> Callable:
+        handler._ws_route = path
+        return handler
+    return decorator
+
+class ExplicitApp(App):
+    """A subclass of MicroPie.App that automatically registers HTTP and WebSocket routes."""
     def __init__(self):
         super().__init__()
         self.router = ExplicitRouter()
+        self.ws_router = WebSocketExplicitRouter()
         self.middlewares.append(self.router)
+        self.ws_middlewares.append(self.ws_router)
         self._register_routes()
     
     def _register_routes(self):
-        """Automatically register routes from decorated methods."""
+        """Automatically register HTTP and WebSocket routes from decorated methods."""
         for name, method in self.__class__.__dict__.items():
             if hasattr(method, "_route"):
                 path, methods = method._route
                 self.router.add_route(path, getattr(self, name), methods)
+            if hasattr(method, "_ws_route"):
+                path = method._ws_route
+                self.ws_router.add_route(path, getattr(self, name))
diff --git a/examples/explicit_routing/test.html b/examples/explicit_routing/test.html
new file mode 100644
index 0000000..1bb63b4
--- /dev/null
+++ b/examples/explicit_routing/test.html
@@ -0,0 +1,25 @@
+<!DOCTYPE html>
+<html>
+<head>
+    <title>WebSocket Test</title>
+</head>
+<body>
+    <input id="message" type="text">
+    <button onclick="sendMessage()">Send</button>
+    <div id="output"></div>
+    <script>
+        const ws = new WebSocket("ws://localhost:8000/chat/myroom?user=Alice");
+        ws.onmessage = function(event) {
+            document.getElementById("output").innerText += event.data + "\n";
+        };
+        ws.onclose = function(event) {
+            document.getElementById("output").innerText += `Closed: ${event.code} ${event.reason}\n`;
+        };
+        function sendMessage() {
+            const msg = document.getElementById("message").value;
+            ws.send(msg);
+            document.getElementById("message").value = "";
+        }
+    </script>
+</body>
+</html>
diff --git a/examples/explicit_routing/ws.py b/examples/explicit_routing/ws.py
new file mode 100644
index 0000000..496849a
--- /dev/null
+++ b/examples/explicit_routing/ws.py
@@ -0,0 +1,22 @@
+from micropie_routing import ExplicitApp, route, ws_route
+from MicroPie import WebSocket, ConnectionClosed
+
+class MyApp(ExplicitApp):
+    @route("/api/users/{user_id:int}", method=["GET"])
+    async def get_user(self, user_id: int):
+        return f"User ID: {user_id}"
+
+    @ws_route("/ws/chat/{room:str}")
+    async def ws_chat(self, ws: WebSocket, room: str):
+        await ws.accept()
+        user = self.request.query_params.get("user", ["anonymous"])[0]
+        self.request.session["last_room"] = room
+        while True:
+            try:
+                message = await ws.receive_text()
+                response = f"{user} ({room}): {message}"
+                await ws.send_text(response)
+            except ConnectionClosed:
+                break
+
+app = MyApp()
diff --git a/examples/websockets/test.html b/examples/websockets/test.html
index 46e8770..1bb63b4 100644
--- a/examples/websockets/test.html
+++ b/examples/websockets/test.html
@@ -8,7 +8,7 @@
     <button onclick="sendMessage()">Send</button>
     <div id="output"></div>
     <script>
-        const ws = new WebSocket("ws://localhost:8000/chat?user=Alice");
+        const ws = new WebSocket("ws://localhost:8000/chat/myroom?user=Alice");
         ws.onmessage = function(event) {
             document.getElementById("output").innerText += event.data + "\n";
         };