patx/micropie
ruff format all code
Commit c9bb7c4 · patx · 2026-01-11T20:02:09-05:00
Comments
No comments yet.
Diff
diff --git a/docs/apidocs/conf.py b/docs/apidocs/conf.py
index f73ad68..a8d0f78 100644
--- a/docs/apidocs/conf.py
+++ b/docs/apidocs/conf.py
@@ -37,4 +37,3 @@ html_theme_options = {
pygments_style = "friendly"
html_static_path = ["_static"]
-
diff --git a/examples/auth/app.py b/examples/auth/app.py
index 1023a5e..f64debb 100644
--- a/examples/auth/app.py
+++ b/examples/auth/app.py
@@ -11,12 +11,15 @@ GITHUB_AUTH_URL = "https://github.com/login/oauth/authorize"
GITHUB_TOKEN_URL = "https://github.com/login/oauth/access_token"
GITHUB_API_URL = "https://api.github.com/user"
+
class Root(App):
async def index(self):
return '<a href="/login">Login with GitHub</a>'
async def login(self):
- return self._redirect(f"{GITHUB_AUTH_URL}?client_id={CLIENT_ID}&redirect_uri={REDIRECT_URI}")
+ return self._redirect(
+ f"{GITHUB_AUTH_URL}?client_id={CLIENT_ID}&redirect_uri={REDIRECT_URI}"
+ )
async def callback(self):
code = self.request.query_params.get("code")
diff --git a/examples/blog/app.py b/examples/blog/app.py
index 702719e..0f731af 100644
--- a/examples/blog/app.py
+++ b/examples/blog/app.py
@@ -17,6 +17,7 @@ COLLECTION_USERS = "users"
USERNAME = "demo"
FULLNAME = "John Smith"
+
def serialize_post(doc: Dict[str, Any]) -> Dict[str, Any]:
"""Convert a Mongo document into a JSON-friendly dict."""
created_at = doc.get("created_at")
@@ -36,6 +37,7 @@ def serialize_post(doc: Dict[str, Any]) -> Dict[str, Any]:
# ---------- Startup / shutdown ----------
+
async def init_db():
try:
print("[init_db] starting init")
@@ -53,18 +55,18 @@ async def init_db():
existing = await app.users.find_one({"username": USERNAME})
if not existing:
- await app.users.insert_one(
- {"username": "demo", "password": "demo"}
- )
+ await app.users.insert_one({"username": "demo", "password": "demo"})
print("[init_db] finished without error")
except Exception as e:
import traceback
+
print("[init_db] ERROR!", repr(e))
traceback.print_exc()
raise
+
async def close_db():
"""
ASGI shutdown handler: close Mongo client.
@@ -238,7 +240,6 @@ class BlogApp(App):
self.request.session.clear()
return self._redirect("/")
-
# ---------- JSON API HANDLERS ----------
async def api_posts(self):
@@ -356,18 +357,13 @@ class BlogApp(App):
return 405, {"error": "Method not allowed on /api_post/<id>."}
-
-app = BlogApp(session_backend=MkvSessionBackend(
- mongo_uri=MONGO_URI,
- db_name=DB_NAME
- )
-)
+app = BlogApp(session_backend=MkvSessionBackend(mongo_uri=MONGO_URI, db_name=DB_NAME))
app.middlewares.append(
MongoRateLimitMiddleware(
mongo_uri=MONGO_URI,
db_name=DB_NAME,
- allowed_hosts=None, # don't enforce host allowlist, change in prod
- trust_proxy_headers=False, # change in prod
+ allowed_hosts=None, # don't enforce host allowlist, change in prod
+ trust_proxy_headers=False, # change in prod
require_cf_ray=False,
)
)
diff --git a/examples/blog/middlewares/rate_limit.py b/examples/blog/middlewares/rate_limit.py
index 29b93fd..8adfa05 100644
--- a/examples/blog/middlewares/rate_limit.py
+++ b/examples/blog/middlewares/rate_limit.py
@@ -126,11 +126,12 @@ class MongoRateLimitMiddleware(HttpMiddleware):
"violations": {"$ifNull": ["$violations", 0]},
"blocked_until": {"$ifNull": ["$blocked_until", None]},
"permanent_blocked": {"$ifNull": ["$permanent_blocked", False]},
- "permanent_blocked_at": {"$ifNull": ["$permanent_blocked_at", None]},
+ "permanent_blocked_at": {
+ "$ifNull": ["$permanent_blocked_at", None]
+ },
"violation_events": {"$ifNull": ["$violation_events", []]},
}
},
-
# 2) Prune old violation events
{
"$set": {
@@ -143,7 +144,6 @@ class MongoRateLimitMiddleware(HttpMiddleware):
}
}
},
-
# 3) Are we currently blocked?
{
"$set": {
@@ -160,7 +160,6 @@ class MongoRateLimitMiddleware(HttpMiddleware):
}
}
},
-
# 4) Update window/count atomically (only if not blocked)
{
"$set": {
@@ -186,12 +185,17 @@ class MongoRateLimitMiddleware(HttpMiddleware):
"$cond": [
"$_blocked_now",
"$count",
- {"$cond": ["$_window_expired", 1, {"$add": ["$count", 1]}]},
+ {
+ "$cond": [
+ "$_window_expired",
+ 1,
+ {"$add": ["$count", 1]},
+ ]
+ },
]
},
}
},
-
# 5) Over limit?
{
"$set": {
@@ -203,12 +207,15 @@ class MongoRateLimitMiddleware(HttpMiddleware):
}
}
},
-
# 6) Record violation if over limit
{
"$set": {
"violations": {
- "$cond": ["$_over_limit", {"$add": ["$violations", 1]}, "$violations"]
+ "$cond": [
+ "$_over_limit",
+ {"$add": ["$violations", 1]},
+ "$violations",
+ ]
},
"violation_events": {
"$cond": [
@@ -219,7 +226,6 @@ class MongoRateLimitMiddleware(HttpMiddleware):
},
}
},
-
# 7) Temporary block escalation
{
"$set": {
@@ -228,7 +234,12 @@ class MongoRateLimitMiddleware(HttpMiddleware):
{
"$and": [
"$_over_limit",
- {"$gte": ["$violations", self.BLOCK_AFTER_VIOLATIONS]},
+ {
+ "$gte": [
+ "$violations",
+ self.BLOCK_AFTER_VIOLATIONS,
+ ]
+ },
]
},
now + timedelta(seconds=self.BLOCK_FOR_SECONDS),
@@ -237,7 +248,6 @@ class MongoRateLimitMiddleware(HttpMiddleware):
}
}
},
-
# 8) Permanent block escalation
{"$set": {"_events_24h": {"$size": "$violation_events"}}},
{
@@ -247,7 +257,12 @@ class MongoRateLimitMiddleware(HttpMiddleware):
{
"$and": [
"$_over_limit",
- {"$gte": ["$_events_24h", self.PERMA_BLOCK_AFTER]},
+ {
+ "$gte": [
+ "$_events_24h",
+ self.PERMA_BLOCK_AFTER,
+ ]
+ },
]
},
True,
@@ -259,7 +274,12 @@ class MongoRateLimitMiddleware(HttpMiddleware):
{
"$and": [
"$_over_limit",
- {"$gte": ["$_events_24h", self.PERMA_BLOCK_AFTER]},
+ {
+ "$gte": [
+ "$_events_24h",
+ self.PERMA_BLOCK_AFTER,
+ ]
+ },
{"$eq": ["$permanent_blocked_at", None]},
]
},
@@ -269,9 +289,15 @@ class MongoRateLimitMiddleware(HttpMiddleware):
},
}
},
-
# 9) Cleanup temp fields
- {"$unset": ["_blocked_now", "_window_expired", "_over_limit", "_events_24h"]},
+ {
+ "$unset": [
+ "_blocked_now",
+ "_window_expired",
+ "_over_limit",
+ "_events_24h",
+ ]
+ },
],
upsert=True,
return_document=ReturnDocument.AFTER,
@@ -308,4 +334,3 @@ class MongoRateLimitMiddleware(HttpMiddleware):
async def after_request(self, request, status_code, response_body, extra_headers):
return None
-
diff --git a/examples/blog/sessions/mongo_session.py b/examples/blog/sessions/mongo_session.py
index a97e863..8504461 100644
--- a/examples/blog/sessions/mongo_session.py
+++ b/examples/blog/sessions/mongo_session.py
@@ -91,4 +91,3 @@ class MkvSessionBackend(SessionBackend):
}
await self.store.set(key, payload)
-
diff --git a/examples/explicit_routing/app.py b/examples/explicit_routing/app.py
index 25a39a0..d71415d 100644
--- a/examples/explicit_routing/app.py
+++ b/examples/explicit_routing/app.py
@@ -1,11 +1,11 @@
from micropie_routing import ExplicitApp, route
-class MyApp(ExplicitApp):
+class MyApp(ExplicitApp):
@route("/api/users/{user:str}/records/{record:int}", method=["GET", "HEAD"])
async def _get_record(self, user: str, record: int):
return {"user": user, "record": record}
-
+
@route("/api/users/{user:str}/records", method=["POST"])
async def _create_record(self, user: str):
try:
@@ -13,11 +13,13 @@ class MyApp(ExplicitApp):
return {"user": user, "record": data.get("record_id"), "created": True}
except Exception:
return {"error": f"Invalid JSON"}
-
- @route("/api/users/{user:str}/records/{record:int}/details/subdetails", method="GET")
+
+ @route(
+ "/api/users/{user:str}/records/{record:int}/details/subdetails", method="GET"
+ )
async def _get_record_subdetails(self, user: str, record: int):
return {"user": user, "record": record, "subdetails": "more detailed info"}
-
+
# Implicitly routed (not using decorator)
async def records(self, user: str, record: str):
try:
@@ -25,9 +27,10 @@ class MyApp(ExplicitApp):
return {"user": user, "record": record_id, "implicit": True}
except ValueError:
return {"error": "Record must be an integer"}
-
+
# Private route, not exposed
async def _private(self):
return {"viewing": "private"}
+
app = MyApp()
diff --git a/examples/explicit_routing/micropie_routing.py b/examples/explicit_routing/micropie_routing.py
index 5b4fcb3..ce7aae3 100644
--- a/examples/explicit_routing/micropie_routing.py
+++ b/examples/explicit_routing/micropie_routing.py
@@ -3,25 +3,35 @@ from typing import Dict, List, Optional, Tuple, Any, Callable, Type, Union
import uuid
from micropie import App, HttpMiddleware, WebSocketMiddleware, Request, WebSocketRequest
+
# Specific exceptions for better error handling
class RouteError(Exception):
"""Base exception for routing errors."""
+
pass
+
class InvalidPathError(RouteError):
"""Raised for invalid route path formats."""
+
pass
+
class UnsupportedTypeError(RouteError):
"""Raised for unsupported parameter types."""
+
pass
+
class InvalidMethodError(RouteError):
"""Raised for invalid HTTP methods."""
+
pass
+
class BaseRouter:
"""Base class for HTTP and WebSocket routers to share logic."""
+
def __init__(self):
# Map route paths to (methods/subprotocol, compiled regex, handler, param_types)
self.routes: Dict[str, Tuple[Any, re.Pattern, Callable, List[Type]]] = {}
@@ -29,7 +39,10 @@ class BaseRouter:
"int": (int, r"(\d+)"),
"str": (str, r"([^/]+)"),
"float": (float, r"([-+]?\d*\.?\d+)"),
- "uuid": (uuid.UUID, r"([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})")
+ "uuid": (
+ uuid.UUID,
+ r"([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})",
+ ),
}
def _process_param(self, match: re.Match, param_types: List[Type]) -> str:
@@ -42,10 +55,12 @@ class BaseRouter:
param_types.append(param_type)
return regex
- def add_route(self, path: str, handler: Callable, methods_or_subprotocol: Any) -> None:
+ def add_route(
+ self, path: str, handler: Callable, methods_or_subprotocol: Any
+ ) -> None:
"""
Register a route with its handler and methods/subprotocol.
-
+
Args:
path: The route pattern (e.g., "/users/{user}/{record:int}")
handler: The handler function
@@ -55,27 +70,42 @@ class BaseRouter:
raise InvalidPathError(f"Route path must start with '/': {path}")
param_types = []
# Support both {name} and {name:type} syntax
- pattern = re.sub(r"{([^:}]*)?(?::([^}]+))?}", lambda m: self._process_param(m, param_types), path)
+ pattern = re.sub(
+ r"{([^:}]*)?(?::([^}]+))?}",
+ lambda m: self._process_param(m, param_types),
+ path,
+ )
try:
compiled_pattern = re.compile(f"^{pattern}$")
except re.error as e:
raise InvalidPathError(f"Invalid route pattern: {path} ({str(e)})")
- self.routes[path] = (methods_or_subprotocol, compiled_pattern, handler, param_types)
+ self.routes[path] = (
+ methods_or_subprotocol,
+ compiled_pattern,
+ handler,
+ param_types,
+ )
def list_routes(self) -> List[Dict[str, Any]]:
"""
Return a list of registered routes for debugging.
-
+
Returns:
List of dictionaries containing route details.
"""
return [
- {"path": path, "methods_or_subprotocol": details[0], "handler": details[2].__name__}
+ {
+ "path": path,
+ "methods_or_subprotocol": details[0],
+ "handler": details[2].__name__,
+ }
for path, details in self.routes.items()
]
+
class ExplicitRouter(HttpMiddleware):
"""Middleware for explicit HTTP routing with type-safe parameters."""
+
def __init__(self):
super().__init__()
self.router = BaseRouter()
@@ -83,7 +113,7 @@ class ExplicitRouter(HttpMiddleware):
def add_route(self, path: str, handler: Callable, methods: List[str]) -> None:
"""
Register an explicit HTTP route.
-
+
Args:
path: The route pattern (e.g., "/users/{user}/{record:int}")
handler: The handler function
@@ -94,34 +124,45 @@ class ExplicitRouter(HttpMiddleware):
valid_methods = {"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}
methods = [m.upper() for m in methods]
if not all(m in valid_methods for m in methods):
- raise InvalidMethodError(f"Invalid HTTP methods: {', '.join(set(methods) - valid_methods)}")
+ raise InvalidMethodError(
+ f"Invalid HTTP methods: {', '.join(set(methods) - valid_methods)}"
+ )
self.router.add_route(path, handler, methods)
async def before_request(self, request: Request) -> Optional[Dict]:
"""
Match the request path and set path parameters.
-
+
Args:
request: The MicroPie Request object
-
+
Returns:
Dictionary with response details to short-circuit, or None to continue.
"""
path = request.scope["path"]
- for route_path, (methods, pattern, handler, param_types) in self.router.routes.items():
+ for route_path, (
+ methods,
+ pattern,
+ handler,
+ param_types,
+ ) in self.router.routes.items():
if request.method not in methods:
continue
match = pattern.match(path)
if match:
try:
params = [
- param_type(param) for param, param_type in zip(match.groups(), param_types)
+ param_type(param)
+ for param, param_type in zip(match.groups(), param_types)
]
request.path_params = params
request._route_handler = handler.__name__
return None
except ValueError as e:
- return {"status_code": 400, "body": f"Invalid parameter format: {str(e)}"}
+ return {
+ "status_code": 400,
+ "body": f"Invalid parameter format: {str(e)}",
+ }
return None
async def after_request(
@@ -129,20 +170,24 @@ class ExplicitRouter(HttpMiddleware):
request: Request,
status_code: int,
response_body: Any,
- extra_headers: List[Tuple[str, str]]
+ extra_headers: List[Tuple[str, str]],
) -> Optional[Dict]:
return None
+
class WebSocketExplicitRouter(WebSocketMiddleware):
"""Middleware for explicit WebSocket routing with type-safe parameters."""
+
def __init__(self):
super().__init__()
self.router = BaseRouter()
- def add_route(self, path: str, handler: Callable, subprotocol: Optional[str] = None) -> None:
+ def add_route(
+ self, path: str, handler: Callable, subprotocol: Optional[str] = None
+ ) -> None:
"""
Register an explicit WebSocket route.
-
+
Args:
path: The route pattern (e.g., "/ws/users/{user}/chat")
handler: The handler function
@@ -153,77 +198,96 @@ class WebSocketExplicitRouter(WebSocketMiddleware):
async def before_websocket(self, request: WebSocketRequest) -> Optional[Dict]:
"""
Match the WebSocket path and set path parameters.
-
+
Args:
request: The WebSocketRequest object
-
+
Returns:
Dictionary with close details to reject, or None to continue.
"""
path = request.scope["path"]
- for route_path, (subprotocol, pattern, handler, param_types) in self.router.routes.items():
+ for route_path, (
+ subprotocol,
+ pattern,
+ handler,
+ param_types,
+ ) in self.router.routes.items():
match = pattern.match(path)
if match:
try:
params = [
- param_type(param) for param, param_type in zip(match.groups(), param_types)
+ param_type(param)
+ for param, param_type in zip(match.groups(), param_types)
]
request.path_params = params
request._ws_route_handler = handler.__name__
request._ws_subprotocol = subprotocol
return None
except ValueError as e:
- return {"code": 1008, "reason": f"Invalid parameter format: {str(e)}"}
+ return {
+ "code": 1008,
+ "reason": f"Invalid parameter format: {str(e)}",
+ }
return None
async def after_websocket(self, request: WebSocketRequest) -> None:
"""Log WebSocket session closure for debugging."""
print(f"WebSocket session closed for path: {request.scope['path']}")
+
def route(path: str, method: Union[str, List[str]] = "GET"):
"""
Decorator to register an HTTP route with validation.
-
+
Args:
path: The route path (e.g., "/users/{user}")
method: HTTP method(s) as a string or list
-
+
Raises:
InvalidMethodError: If invalid HTTP methods are provided
InvalidPathError: If the path format is invalid
"""
+
def decorator(handler: Callable) -> Callable:
methods = [method] if isinstance(method, str) else method
valid_methods = {"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}
methods = [m.upper() for m in methods]
if not all(m in valid_methods for m in methods):
- raise InvalidMethodError(f"Invalid HTTP methods: {', '.join(set(methods) - valid_methods)}")
+ raise InvalidMethodError(
+ f"Invalid HTTP methods: {', '.join(set(methods) - valid_methods)}"
+ )
if not path.startswith("/"):
raise InvalidPathError(f"Route path must start with '/': {path}")
handler._route = (path, methods)
return handler
+
return decorator
+
def ws_route(path: str, subprotocol: Optional[str] = None):
"""
Decorator to register a WebSocket route with optional subprotocol.
-
+
Args:
path: The WebSocket route path (e.g., "/ws/chat/{room}")
subprotocol: Optional WebSocket subprotocol
-
+
Raises:
InvalidPathError: If the path format is invalid
"""
+
def decorator(handler: Callable) -> Callable:
if not path.startswith("/"):
raise InvalidPathError(f"Route path must start with '/': {path}")
handler._ws_route = (path, subprotocol)
return handler
+
return decorator
+
class ExplicitApp(App):
"""A MicroPie App subclass with explicit routing support."""
+
def __init__(self, session_backend=None):
super().__init__(session_backend=session_backend)
self.router = ExplicitRouter()
@@ -245,11 +309,11 @@ class ExplicitApp(App):
def list_routes(self) -> Dict[str, List[Dict[str, Any]]]:
"""
Return all registered HTTP and WebSocket routes for debugging.
-
+
Returns:
Dictionary with 'http' and 'websocket' keys containing route details.
"""
return {
"http": self.router.router.list_routes(),
- "websocket": self.ws_router.router.list_routes()
+ "websocket": self.ws_router.router.list_routes(),
}
diff --git a/examples/explicit_routing/ws.py b/examples/explicit_routing/ws.py
index 6e2e162..ba53876 100644
--- a/examples/explicit_routing/ws.py
+++ b/examples/explicit_routing/ws.py
@@ -1,6 +1,7 @@
from micropie_routing import ExplicitApp, route, ws_route
from micropie import WebSocket, ConnectionClosed
+
class MyApp(ExplicitApp):
@route("/api/users/{user_id:int}", method=["GET"])
async def get_user(self, user_id: int):
@@ -19,4 +20,5 @@ class MyApp(ExplicitApp):
except ConnectionClosed:
break
+
app = MyApp()
diff --git a/examples/file_uploads/app.py b/examples/file_uploads/app.py
index 3343056..0106b87 100644
--- a/examples/file_uploads/app.py
+++ b/examples/file_uploads/app.py
@@ -1,10 +1,11 @@
-import os # Used for file path handling and directory creation
-import aiofiles # Asynchronous file I/O operations
+import os # Used for file path handling and directory creation
+import aiofiles # Asynchronous file I/O operations
from micropie import App # Import the base App class from MicroPie
# Ensure the "uploads" directory exists; create it if it doesn't
os.makedirs("uploads", exist_ok=True)
+
class Root(App):
"""
This is the main application class that inherits from MicroPie's App.
@@ -13,7 +14,7 @@ class Root(App):
async def index(self):
"""
- Serve a simple HTML form that lets the user choose a
+ Serve a simple HTML form that lets the user choose a
file and submit it via POST to /upload.
"""
return """<form action="/upload" method="post" enctype="multipart/form-data">
@@ -26,18 +27,18 @@ class Root(App):
Handle the uploaded file from the client:
- Saves the file to disk in the "uploads" directory.
- Uses aiofiles to write the file asynchronously, in chunks.
-
+
`file` is a dictionary with:
'filename': The original filename of the uploaded file.
- 'content_type': The MIME type of the file (defaults
+ 'content_type': The MIME type of the file (defaults
to application/octet-stream).
- 'content': An asyncio.Queue containing chunks of file data as
+ 'content': An asyncio.Queue containing chunks of file data as
bytes, with a None sentinel signaling the end of the stream.
"""
# Construct a safe path to save the uploaded file
filepath = os.path.join("uploads", file["filename"])
-
+
# Open the destination file asynchronously for writing
async with aiofiles.open(filepath, "wb") as f:
# Read and write the file in chunks
@@ -48,6 +49,6 @@ class Root(App):
# Return a confirmation response with the uploaded filename
return 200, f"Uploaded {file['filename']}"
+
# Instantiate the app
app = Root()
-
diff --git a/examples/headers/app.py b/examples/headers/app.py
index 2fa0456..9940439 100644
--- a/examples/headers/app.py
+++ b/examples/headers/app.py
@@ -9,10 +9,9 @@ class Root(App):
("X-Frame-Options", "DENY"),
("X-XSS-Protection", "1; mode=block"),
("Strict-Transport-Security", "max-age=31536000; includeSubDomains"),
- ("Content-Security-Policy", "default-src 'self'")
+ ("Content-Security-Policy", "default-src 'self'"),
]
return 200, "<b>hello world</b>", headers
-
app = Root()
diff --git a/examples/hello_world/app.py b/examples/hello_world/app.py
index a598eb7..7e6cb52 100644
--- a/examples/hello_world/app.py
+++ b/examples/hello_world/app.py
@@ -2,13 +2,13 @@ from micropie import App
class Root(App):
-
async def index(self):
- return 'Hello ASGI World!'
+ return "Hello ASGI World!"
- async def greet(self, first_name='World', last_name=None):
+ async def greet(self, first_name="World", last_name=None):
if last_name:
- return f'Hello {first_name} {last_name}'
- return f'Hello {first_name}'
+ return f"Hello {first_name} {last_name}"
+ return f"Hello {first_name}"
+
-app = Root() # Run with `uvicorn app:app`
+app = Root() # Run with `uvicorn app:app`
diff --git a/examples/json_api/app.py b/examples/json_api/app.py
index e16e478..8b390e1 100644
--- a/examples/json_api/app.py
+++ b/examples/json_api/app.py
@@ -2,10 +2,9 @@ from micropie import App
class Root(App):
-
async def index(self, id, name, age):
if self.request.method == "POST":
- return {"id": id,"name": name,"age": age}
+ return {"id": id, "name": name, "age": age}
async def echo(self):
if self.request.method == "GET":
@@ -20,4 +19,5 @@ class Root(App):
async def html(self):
return "<b>Hello world</b>"
+
app = Root()
diff --git a/examples/json_api/basic.py b/examples/json_api/basic.py
index 73f287f..0f74420 100644
--- a/examples/json_api/basic.py
+++ b/examples/json_api/basic.py
@@ -2,57 +2,46 @@ from micropie import App
from pickledb import AsyncPickleDB
from uuid import uuid4
-db = AsyncPickleDB('pastes.db')
+db = AsyncPickleDB("pastes.db")
class PasteApp(App):
-
async def paste(self, pid: str = None):
if self.request.method == "POST":
# Get content from JSON or form, depending on the Content-Type
- content = self.request.body_params.get('content')[0]
+ content = self.request.body_params.get("content")[0]
pid = str(uuid4())
await db.aset(pid, content)
return {
"status": "success",
"action": "post",
"paste_id": pid,
- "content": content
+ "content": content,
}
elif self.request.method == "DELETE":
await db.aremove(pid)
- return {
- "status": "success",
- "action": "delete",
- "paste_id": pid
- }
+ return {"status": "success", "action": "delete", "paste_id": pid}
elif self.request.method == "GET":
if pid:
paste = await db.aget(pid)
if paste is None:
- return 404, orjson.dumps({
- "status": "fail",
- "error": "Paste not found"
- })
+ return 404, orjson.dumps(
+ {"status": "fail", "error": "Paste not found"}
+ )
return {
"status": "success",
"action": "get",
"paste_id": pid,
- "content": paste
+ "content": paste,
}
all_keys = await db.aall()
- all_pastes = [{
- "paste_id": key,
- "content": await db.aget(key)
- } for key in all_keys]
- return 302, {
- "status": "success",
- "action": "get all",
- "pastes": all_pastes
- }
+ all_pastes = [
+ {"paste_id": key, "content": await db.aget(key)} for key in all_keys
+ ]
+ return 302, {"status": "success", "action": "get all", "pastes": all_pastes}
app = PasteApp()
diff --git a/examples/middleware/basic.py b/examples/middleware/basic.py
index 713bb1d..74bfab0 100644
--- a/examples/middleware/basic.py
+++ b/examples/middleware/basic.py
@@ -1,5 +1,6 @@
from micropie import App, HttpMiddleware
+
class MiddlewareExample(HttpMiddleware):
async def before_request(self, request):
print("Hook before request")
@@ -7,10 +8,12 @@ class MiddlewareExample(HttpMiddleware):
async def after_request(self, request, status_code, response_body, extra_headers):
print("Hook after request")
+
class Root(App):
async def index(self):
print("Hello, World!")
return "Hello, World!"
+
app = Root()
app.middlewares.append(MiddlewareExample())
diff --git a/examples/middleware/csrf.py b/examples/middleware/csrf.py
index e76763f..f54e5ee 100644
--- a/examples/middleware/csrf.py
+++ b/examples/middleware/csrf.py
@@ -32,7 +32,9 @@ class CSRFMiddleware(HttpMiddleware):
self.app = app
self.serializer = URLSafeTimedSerializer(secret_key, salt="csrf-token")
self.max_age = max_age
- self.trusted = set(trusted_origins or []) # e.g. ["https://gardenfresh.vegy.app"]
+ self.trusted = set(
+ trusted_origins or []
+ ) # e.g. ["https://gardenfresh.vegy.app"]
self.body_field = body_field
self.header_name = header_name.lower()
self.require_header_for_multipart = require_header_for_multipart
@@ -149,7 +151,6 @@ class CSRFMiddleware(HttpMiddleware):
class Root(App):
-
async def index(self):
csrf_token = self.request.session.get("csrf_token", "")
return f"""<form method="POST" action="/submit">
@@ -172,4 +173,3 @@ app.middlewares.append(
exempt_paths=["/sms_order"],
)
)
-
diff --git a/examples/middleware/rate_limit.py b/examples/middleware/rate_limit.py
index a61df36..46047f9 100644
--- a/examples/middleware/rate_limit.py
+++ b/examples/middleware/rate_limit.py
@@ -14,14 +14,14 @@ class MongoRateLimitMiddleware(HttpMiddleware):
- Permanent block if too many violations in a 24h period.
"""
- MAX_REQUESTS = 50 # allowed per window
- WINDOW_SECONDS = 60 # window length in seconds
+ MAX_REQUESTS = 50 # allowed per window
+ WINDOW_SECONDS = 60 # window length in seconds
- BLOCK_AFTER_VIOLATIONS = 3 # how many windows exceeded before temp block
- BLOCK_FOR_SECONDS = 900 # how long to temporarily block (seconds)
+ BLOCK_AFTER_VIOLATIONS = 3 # how many windows exceeded before temp block
+ BLOCK_FOR_SECONDS = 900 # how long to temporarily block (seconds)
- PERMA_WINDOW_HOURS = 24 # lookback window for permanent block
- PERMA_BLOCK_AFTER = 10 # violations in window before permanent block
+ PERMA_WINDOW_HOURS = 24 # lookback window for permanent block
+ PERMA_BLOCK_AFTER = 10 # violations in window before permanent block
def __init__(
self,
@@ -175,7 +175,6 @@ class MongoRateLimitMiddleware(HttpMiddleware):
class MyApp(App):
-
async def index(self):
if "visits" not in self.request.session:
self.request.session["visits"] = 1
diff --git a/examples/middleware/router.py b/examples/middleware/router.py
index 1e0d903..781aef1 100644
--- a/examples/middleware/router.py
+++ b/examples/middleware/router.py
@@ -1,7 +1,7 @@
"""
Example of how you can implement explicit routes using middleware.
For a comprehensive example of this (with route decorators and type
-checking see the `rest` example at
+checking see the `rest` example at
https://github.com/patx/micropie/tree/main/examples/rest
"""
@@ -9,15 +9,16 @@ import re
from typing import Dict, List, Optional, Tuple, Any
from micropie import App, HttpMiddleware, Request
+
class ExplicitRouter(HttpMiddleware):
def __init__(self):
# Map route paths to (method, regex pattern, handler name)
self.routes: Dict[str, Tuple[str, str, str]] = {}
-
+
def add_route(self, path: str, handler_name: str, method: str = "GET") -> None:
"""
Register an explicit route with its handler method name and HTTP method.
-
+
Args:
path: The route pattern (e.g., "/api/users/{user}/records/{record}")
handler_name: The handler method name (e.g., "api")
@@ -26,19 +27,19 @@ class ExplicitRouter(HttpMiddleware):
pattern = re.sub(r"{([^}]+)}", r"([^/]+)", path)
pattern = f"^{pattern}$"
self.routes[path] = (method, pattern, handler_name)
-
+
async def before_request(self, request: Request) -> Optional[Dict]:
"""
Match the request path and set path parameters for MicroPie routing.
-
+
Args:
request: The MicroPie Request object
-
+
Returns:
None to let MicroPie handle parsing and routing
"""
path = request.scope["path"]
-
+
for route_path, (method, pattern, handler_name) in self.routes.items():
if request.method != method:
continue
@@ -48,15 +49,15 @@ class ExplicitRouter(HttpMiddleware):
request.path_params = [str(param) for param in match.groups()]
request._route_handler = handler_name
return None
-
+
return None
-
+
async def after_request(
self,
request: Request,
status_code: int,
response_body: Any,
- extra_headers: List[Tuple[str, str]]
+ extra_headers: List[Tuple[str, str]],
) -> Optional[Dict]:
return None
@@ -66,33 +67,43 @@ class MyApp(App):
super().__init__()
self.router = ExplicitRouter()
self.middlewares.append(self.router)
-
+
# Register explicit routes
- self.router.add_route("/api/users/{user}/records/{record}", "_get_record", "GET")
+ self.router.add_route(
+ "/api/users/{user}/records/{record}", "_get_record", "GET"
+ )
self.router.add_route("/api/users/{user}/records", "_create_record", "POST")
- self.router.add_route("/api/users/{user}/records/{record}/details/subdetails", "_get_record_subdetails", "GET")
-
+ self.router.add_route(
+ "/api/users/{user}/records/{record}/details/subdetails",
+ "_get_record_subdetails",
+ "GET",
+ )
+
async def _get_record(self, user: str, record: str):
try:
record_id = int(record)
return {"user": user, "record": record_id}
except ValueError:
return {"error": "Record must be an integer"}
-
+
async def _create_record(self, user: str):
try:
data = self.request.get_json
return {"user": user, "record": data.get("record_id"), "created": True}
except Exception as e:
return {"error": f"Invalid JSON: {str(e)}"}
-
+
async def _get_record_subdetails(self, user: str, record: str):
try:
record_id = int(record)
- return {"user": user, "record": record_id, "subdetails": "more detailed info"}
+ return {
+ "user": user,
+ "record": record_id,
+ "subdetails": "more detailed info",
+ }
except ValueError:
return {"error": "Record must be an integer"}
-
+
# Implicitly routed
async def records(self, user: str, record: str):
try:
@@ -105,4 +116,5 @@ class MyApp(App):
async def _private(self):
return {"viewing": "private"}
+
app = MyApp()
diff --git a/examples/middleware/sessions.py b/examples/middleware/sessions.py
index 47518cc..e4fbcde 100644
--- a/examples/middleware/sessions.py
+++ b/examples/middleware/sessions.py
@@ -8,6 +8,7 @@ from micropie import App, HttpMiddleware, Request, SESSION_TIMEOUT
class SignedSessionMiddleware(HttpMiddleware):
"""Middleware to sign and verify session cookies using itsdangerous."""
+
def __init__(self, app: App, secret_key: str, max_age: int = SESSION_TIMEOUT):
self.app = app # Store the App instance
self.serializer = URLSafeTimedSerializer(secret_key)
@@ -25,7 +26,11 @@ class SignedSessionMiddleware(HttpMiddleware):
return None
async def after_request(
- self, request: Request, status_code: int, response_body: Any, extra_headers: List[Tuple[str, str]]
+ self,
+ request: Request,
+ status_code: int,
+ response_body: Any,
+ extra_headers: List[Tuple[str, str]],
) -> Optional[Dict]:
"""Sign and set the session_id cookie after processing the request."""
if request.session:
@@ -38,14 +43,20 @@ class SignedSessionMiddleware(HttpMiddleware):
signed_session_id = self.serializer.dumps(session_id)
current_session = await self.app.session_backend.load(session_id) or {}
if current_session != request.session:
- await self.app.session_backend.save(session_id, request.session, SESSION_TIMEOUT)
+ await self.app.session_backend.save(
+ session_id, request.session, SESSION_TIMEOUT
+ )
if not cookies.get("session_id"):
- extra_headers.append(("Set-Cookie", f"session_id={signed_session_id}; Path=/; SameSite=Lax; HttpOnly; Secure;"))
+ extra_headers.append(
+ (
+ "Set-Cookie",
+ f"session_id={signed_session_id}; Path=/; SameSite=Lax; HttpOnly; Secure;",
+ )
+ )
return None
class Root(App):
-
async def index(self):
if "visits" not in self.request.session:
self.request.session["visits"] = 1
diff --git a/examples/middleware/subapp.py b/examples/middleware/subapp.py
index 0439669..5726fa2 100644
--- a/examples/middleware/subapp.py
+++ b/examples/middleware/subapp.py
@@ -4,8 +4,10 @@ import uuid
from itsdangerous import URLSafeTimedSerializer, BadSignature
from micropie import App, HttpMiddleware, Request
+
class CSRFMiddleware(HttpMiddleware):
"""Middleware for CSRF protection using itsdangerous-signed tokens."""
+
def __init__(self, app: App, secret_key: str, max_age: int = 8 * 3600):
self.app = app # Store the App instance
self.serializer = URLSafeTimedSerializer(secret_key, salt="csrf-token")
@@ -14,7 +16,11 @@ class CSRFMiddleware(HttpMiddleware):
async def before_request(self, request: Request) -> Optional[Dict]:
"""Verify CSRF token for POST/PUT/PATCH requests and generate a new token if needed."""
# Extract session ID from cookies or generate a new one
- session_id = request.headers.get("cookie", "").split("session_id=")[-1].split(";")[0] if "session_id=" in request.headers.get("cookie", "") else str(uuid.uuid4())
+ session_id = (
+ request.headers.get("cookie", "").split("session_id=")[-1].split(";")[0]
+ if "session_id=" in request.headers.get("cookie", "")
+ else str(uuid.uuid4())
+ )
if request.method in ("POST", "PUT", "PATCH"):
print(f"Request body_params: {request.body_params}") # Debugging
@@ -35,18 +41,25 @@ class CSRFMiddleware(HttpMiddleware):
request.session["csrf_token"] = signed_token
# Save the session
print(f"Saving session with CSRF token: {signed_token}") # Debugging
- await self.app.session_backend.save(session_id, request.session, self.max_age)
+ await self.app.session_backend.save(
+ session_id, request.session, self.max_age
+ )
return None
async def after_request(
- self, request: Request, status_code: int, response_body: Any, extra_headers: List[Tuple[str, str]]
+ self,
+ request: Request,
+ status_code: int,
+ response_body: Any,
+ extra_headers: List[Tuple[str, str]],
) -> Optional[Dict]:
"""Include CSRF token in response headers for client-side use."""
if request.session.get("csrf_token"):
extra_headers.append(("X-CSRF-Token", request.session["csrf_token"]))
return None
-
+
+
# Define the Sub-App
class ApiApp(App):
async def index(self):
@@ -70,17 +83,19 @@ class ApiApp(App):
<input type="text" name="name">
<button type="submit">Submit</button>
</form>"""
-
+
async def plogin(self, name):
return f"Hello {name}"
+
class UserApp(App):
async def index(self):
return {"msg": "Hello world"}
-
+
async def hello(self, name="world"):
return f"hello {name}"
+
# Define a Middleware to Mount the Sub-App
class SubAppMiddleware(HttpMiddleware):
def __init__(self, mount_path: str, subapp: App):
@@ -92,14 +107,14 @@ class SubAppMiddleware(HttpMiddleware):
if path.startswith(self.mount_path):
# Set the subapp and the remaining path
request._subapp = self.subapp
- request._subapp_path = path[len(self.mount_path):].lstrip("/") or "/"
+ request._subapp_path = path[len(self.mount_path) :].lstrip("/") or "/"
return None # Continue processing
return None # Not a subapp path, continue with main app
- async def after_request(
- self, request, status_code, response_body, extra_headers):
+ async def after_request(self, request, status_code, response_body, extra_headers):
return None # No changes to response
+
# Define the Main App
class MainApp(App):
async def index(self):
@@ -108,6 +123,7 @@ class MainApp(App):
async def hello(self, name: str):
return {"message": f"Hello, {name} from Main App!"}
+
# Create and Configure the Apps
app = MainApp()
api_app = ApiApp()
diff --git a/examples/middleware/upload.py b/examples/middleware/upload.py
index eb7aa50..bcff848 100644
--- a/examples/middleware/upload.py
+++ b/examples/middleware/upload.py
@@ -3,6 +3,7 @@ import asyncio
MAX_UPLOAD_SIZE = 100 * 1024 * 1024 # 100MB
+
class MaxUploadSizeMiddleware(HttpMiddleware):
async def before_request(self, request):
# Check if we're dealing with a POST, PUT, or PATCH request
@@ -12,20 +13,22 @@ class MaxUploadSizeMiddleware(HttpMiddleware):
if content_length is None:
return {
"status_code": 400,
- "body": "400 Bad Request: Missing Content-Length header"
+ "body": "400 Bad Request: Missing Content-Length header",
}
try:
content_length = int(content_length)
if content_length > MAX_UPLOAD_SIZE:
- print(f"Upload rejected: Content-Length ({content_length}) exceeds {MAX_UPLOAD_SIZE} bytes")
+ print(
+ f"Upload rejected: Content-Length ({content_length}) exceeds {MAX_UPLOAD_SIZE} bytes"
+ )
return {
"status_code": 413,
- "body": "413 Payload Too Large: Uploaded file exceeds size limit."
+ "body": "413 Payload Too Large: Uploaded file exceeds size limit.",
}
except ValueError:
return {
"status_code": 400,
- "body": "400 Bad Request: Invalid Content-Length header"
+ "body": "400 Bad Request: Invalid Content-Length header",
}
# Continue processing if checks pass
return None
@@ -49,7 +52,7 @@ class FileUploadApp(App):
</form>
</body>
</html>""",
- [("Content-Type", "text/html; charset=utf-8")]
+ [("Content-Type", "text/html; charset=utf-8")],
)
async def upload(self, file):
@@ -67,11 +70,7 @@ class FileUploadApp(App):
total_size += len(chunk)
# Example: Process chunk (e.g., save to disk, validate, etc.)
# For demonstration, just count the size
- return {
- "filename": filename,
- "content_type": content_type,
- "size": total_size
- }
+ return {"filename": filename, "content_type": content_type, "size": total_size}
app = FileUploadApp()
diff --git a/examples/pastebin/app.py b/examples/pastebin/app.py
index 73bbbde..2adc87d 100644
--- a/examples/pastebin/app.py
+++ b/examples/pastebin/app.py
@@ -6,7 +6,6 @@ pastes = Mkv("mongodb://localhost:27017")
class Root(App):
-
async def index(self, paste_content=None):
if self.request.method == "POST":
new_id = await pastes.set(None, paste_content)
@@ -16,11 +15,11 @@ class Root(App):
async def paste(self, paste_id):
paste = await pastes.get(paste_id, "404: Paste Not Found")
- return await self._render_template("paste.html",
+ return await self._render_template(
+ "paste.html",
paste_id=paste_id,
paste_content=paste,
)
app = Root()
-
diff --git a/examples/server_sent_events/app.py b/examples/server_sent_events/app.py
index bd469f0..e412d8d 100644
--- a/examples/server_sent_events/app.py
+++ b/examples/server_sent_events/app.py
@@ -4,7 +4,6 @@ from micropie import App
class MyApp(App):
-
async def index(self):
return """
<!DOCTYPE html>
@@ -48,8 +47,10 @@ class MyApp(App):
# Return status code, async generator, and headers for SSE
return 200, event_generator(), [("Content-Type", "text/event-stream")]
+
app = MyApp()
if __name__ == "__main__":
import uvicorn
+
uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/examples/server_sent_events/chat.py b/examples/server_sent_events/chat.py
index 5dcdbc5..8844b6d 100644
--- a/examples/server_sent_events/chat.py
+++ b/examples/server_sent_events/chat.py
@@ -3,6 +3,7 @@ from collections import deque
import json
import asyncio
+
class ChatApp(App):
def __init__(self):
super().__init__()
@@ -47,7 +48,7 @@ class ChatApp(App):
};
</script>
"""
- return 200, html, [('Content-Type', 'text/html')]
+ return 200, html, [("Content-Type", "text/html")]
async def send(self):
data = self.request.get_json
@@ -81,10 +82,16 @@ class ChatApp(App):
break
finally:
self.clients.discard(queue)
- return 200, stream(), [
- ('Content-Type', 'text/event-stream'),
- ('Cache-Control', 'no-cache'),
- ('Connection', 'keep-alive')
- ]
+
+ return (
+ 200,
+ stream(),
+ [
+ ("Content-Type", "text/event-stream"),
+ ("Cache-Control", "no-cache"),
+ ("Connection", "keep-alive"),
+ ],
+ )
+
app = ChatApp()
diff --git a/examples/sessions/in_memory.py b/examples/sessions/in_memory.py
index fb82ff5..86d2ac3 100644
--- a/examples/sessions/in_memory.py
+++ b/examples/sessions/in_memory.py
@@ -1,5 +1,6 @@
from micropie import App
+
class MyApp(App):
async def index(self):
# Use self.request.session to access the session data.
@@ -9,4 +10,5 @@ class MyApp(App):
self.request.session["visits"] += 1
return f"You have visited {self.request.session['visits']} times."
+
app = MyApp()
diff --git a/examples/sessions/motor_backend.py b/examples/sessions/motor_backend.py
index ffa050a..83ed4fe 100644
--- a/examples/sessions/motor_backend.py
+++ b/examples/sessions/motor_backend.py
@@ -101,4 +101,3 @@ backend = MotorSessionBackend(MONGO_URI, DB_NAME)
# Pass the Motor session backend to our application.
app = MyApp(session_backend=backend)
-
diff --git a/examples/socketio/basic/chatroom.py b/examples/socketio/basic/chatroom.py
index 53f247d..375a564 100644
--- a/examples/socketio/basic/chatroom.py
+++ b/examples/socketio/basic/chatroom.py
@@ -13,11 +13,13 @@ sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
# Store connected users
connected_users = {}
+
# Create the MicroPie server
class MyApp(App):
async def index(self):
return await self._render_template("chat.html")
+
# Socket.IO event handlers
@sio.event
async def connect(sid, environ):
@@ -25,15 +27,20 @@ async def connect(sid, environ):
# Add user with temporary username
connected_users[sid] = f"User_{sid[:4]}"
await update_user_list()
-
+
# Send recent messages to the newly connected client
recent_messages = db.search("type", "message", limit=50)
for msg in recent_messages:
- await sio.emit('message', {
- 'username': msg['username'],
- 'message': msg['message'],
- 'timestamp': msg['timestamp']
- }, room=sid)
+ await sio.emit(
+ "message",
+ {
+ "username": msg["username"],
+ "message": msg["message"],
+ "timestamp": msg["timestamp"],
+ },
+ room=sid,
+ )
+
@sio.event
async def disconnect(sid):
@@ -43,45 +50,53 @@ async def disconnect(sid):
del connected_users[sid]
await update_user_list()
+
@sio.event
async def set_username(sid, data):
"""Handle username setting"""
- username = data.get('username', '').strip()
+ username = data.get("username", "").strip()
if username and len(username) <= 20: # Basic validation
if username in connected_users.values():
- await sio.emit('error', {'message': 'Invalid username'}, room=sid)
+ await sio.emit("error", {"message": "Invalid username"}, room=sid)
connected_users[sid] = username
print(f"User {sid} set username to {username}")
await update_user_list()
else:
- await sio.emit('error', {'message': 'Invalid username'}, room=sid)
+ await sio.emit("error", {"message": "Invalid username"}, room=sid)
+
@sio.event
async def message(sid, data):
"""Handle and store messages"""
username = connected_users.get(sid, f"User_{sid[:4]}")
- message = data.get('message', '').strip()
-
+ message = data.get("message", "").strip()
+
if message:
# Store message in KenobiDB
message_doc = {
- 'type': 'message',
- 'username': username,
- 'message': message,
- 'timestamp': datetime.utcnow().isoformat()
+ "type": "message",
+ "username": username,
+ "message": message,
+ "timestamp": datetime.utcnow().isoformat(),
}
db.insert(message_doc)
-
+
# Broadcast message to all clients
- await sio.emit('message', {
- 'username': username,
- 'message': message,
- 'timestamp': message_doc['timestamp']
- }, room=None)
+ await sio.emit(
+ "message",
+ {
+ "username": username,
+ "message": message,
+ "timestamp": message_doc["timestamp"],
+ },
+ room=None,
+ )
+
async def update_user_list():
"""Broadcast updated user list to all clients"""
- await sio.emit('user_list', list(connected_users.values()), room=None)
+ await sio.emit("user_list", list(connected_users.values()), room=None)
+
# Attach Socket.IO to the ASGI app
asgi_app = MyApp()
@@ -89,4 +104,5 @@ app = socketio.ASGIApp(sio, asgi_app)
# Ensure database is closed properly on shutdown
import atexit
+
atexit.register(db.close)
diff --git a/examples/socketio/basic/webcam.py b/examples/socketio/basic/webcam.py
index 45d11cd..3e0865f 100644
--- a/examples/socketio/basic/webcam.py
+++ b/examples/socketio/basic/webcam.py
@@ -7,6 +7,7 @@ sio = socketio.AsyncServer(async_mode="asgi")
# Track active users and their watchers/streamers
active_users = set()
+
# MicroPie Server with integrated Socket.IO
class MyApp(App):
async def index(self):
@@ -15,25 +16,40 @@ class MyApp(App):
async def submit(self, username: str, action: str):
if username:
active_users.add(username)
- route = f"/stream/{username}" if action == "Start Streaming" else f"/watch/{username}"
+ route = (
+ f"/stream/{username}"
+ if action == "Start Streaming"
+ else f"/watch/{username}"
+ )
return self.redirect(route)
return self.redirect("/")
async def stream(self, username: str):
- return await self.render_template("stream.html", username=username) if username in active_users else self.redirect("/")
+ return (
+ await self.render_template("stream.html", username=username)
+ if username in active_users
+ else self.redirect("/")
+ )
async def watch(self, username: str):
- return await self.render_template("watch.html", username=username) if username in active_users else self.redirect("/")
+ return (
+ await self.render_template("watch.html", username=username)
+ if username in active_users
+ else self.redirect("/")
+ )
+
# Socket.IO event handlers
@sio.event
async def connect(sid, environ):
print(f"Client connected: {sid}")
+
@sio.event
async def disconnect(sid):
print(f"Client disconnected: {sid}")
+
@sio.on("stream_frame")
async def handle_stream_frame(sid, data):
"""
@@ -50,6 +66,7 @@ async def handle_stream_frame(sid, data):
room=username,
)
+
@sio.on("join_room")
async def join_room(sid, data):
"""Add a client to a room (either as a streamer or watcher)."""
@@ -58,6 +75,7 @@ async def join_room(sid, data):
await sio.enter_room(sid, username) # Await the method
print(f"{sid} joined room for {username}")
+
@sio.on("leave_room")
async def leave_room(sid, data):
"""Remove a client from a room."""
@@ -66,6 +84,7 @@ async def leave_room(sid, data):
sio.leave_room(sid, username)
print(f"{sid} left room for {username}")
+
# Attach the Socket.IO server to the ASGI app
asgi_app = MyApp()
app = socketio.ASGIApp(sio, asgi_app)
diff --git a/examples/socketio/chatroom/app.py b/examples/socketio/chatroom/app.py
index 2808209..3db53d8 100644
--- a/examples/socketio/chatroom/app.py
+++ b/examples/socketio/chatroom/app.py
@@ -18,6 +18,7 @@ connected_users = {}
# Store channel passwords {channel: password or None}
channel_passwords = {}
+
# Create the MicroPie server
class MyApp(App):
async def index(self):
@@ -28,15 +29,19 @@ class MyApp(App):
"""Render specific channel page"""
channel_name = unquote(channel_name)
# Validate channel name
- if not re.match(r'^[a-zA-Z0-9_-]{1,30}$', channel_name):
+ if not re.match(r"^[a-zA-Z0-9_-]{1,30}$", channel_name):
return {"error": "Invalid channel name"}, 400
- return await self._render_template("chat.html", channel=channel_name, is_index=False)
+ return await self._render_template(
+ "chat.html", channel=channel_name, is_index=False
+ )
+
# Socket.IO event handlers
@sio.event
async def connect(sid, environ):
print(f"Client connected: {sid}")
+
@sio.event
async def disconnect(sid):
print(f"Client disconnected: {sid}")
@@ -46,19 +51,20 @@ async def disconnect(sid):
del connected_users[channel][sid]
await update_user_list(channel)
+
@sio.event
async def join_channel(sid, data):
"""Handle joining a channel with optional password"""
- channel = data.get('channel', '').strip()
- password = data.get('password', '')
- username = data.get('username', '').strip()
+ channel = data.get("channel", "").strip()
+ password = data.get("password", "")
+ username = data.get("username", "").strip()
- if not channel or not re.match(r'^[a-zA-Z0-9_-]{1,30}$', channel):
- await sio.emit('error', {'message': 'Invalid channel name'}, room=sid)
+ if not channel or not re.match(r"^[a-zA-Z0-9_-]{1,30}$", channel):
+ await sio.emit("error", {"message": "Invalid channel name"}, room=sid)
return
if not username or len(username) > 20:
- await sio.emit('error', {'message': 'Invalid username'}, room=sid)
+ await sio.emit("error", {"message": "Invalid username"}, room=sid)
return
# Initialize channel if it doesn't exist
@@ -66,18 +72,22 @@ async def join_channel(sid, data):
connected_users[channel] = {}
# Check password if required
- if channel in channel_passwords and channel_passwords[channel] and password != channel_passwords[channel]:
- await sio.emit('error', {'message': 'Incorrect password'}, room=sid)
+ if (
+ channel in channel_passwords
+ and channel_passwords[channel]
+ and password != channel_passwords[channel]
+ ):
+ await sio.emit("error", {"message": "Incorrect password"}, room=sid)
return
# Check if username is taken in this channel
if username in connected_users[channel].values():
- await sio.emit('error', {'message': 'Username already taken'}, room=sid)
+ await sio.emit("error", {"message": "Username already taken"}, room=sid)
return
# Join the socket.io room for this channel
await sio.enter_room(sid, channel)
-
+
# Add user to channel
connected_users[channel][sid] = username
print(f"User {sid} joined channel {channel} as {username}")
@@ -85,51 +95,61 @@ async def join_channel(sid, data):
# Send recent messages for this channel
try:
# Get all messages with type="message"
- all_messages = db.search("type", "message", limit=1000) # Use a high limit to ensure we get all messages
+ all_messages = db.search(
+ "type", "message", limit=1000
+ ) # Use a high limit to ensure we get all messages
# Filter messages for the specific channel
- recent_messages = [msg for msg in all_messages if msg.get('channel') == channel][:50] # Limit to 50
+ recent_messages = [
+ msg for msg in all_messages if msg.get("channel") == channel
+ ][:50] # Limit to 50
print(f"Retrieved {len(recent_messages)} recent messages for channel {channel}")
for msg in recent_messages:
- await sio.emit('message', {
- 'username': msg['username'],
- 'message': msg['message'],
- 'timestamp': msg['timestamp']
- }, room=sid)
+ await sio.emit(
+ "message",
+ {
+ "username": msg["username"],
+ "message": msg["message"],
+ "timestamp": msg["timestamp"],
+ },
+ room=sid,
+ )
except Exception as e:
print(f"Error retrieving messages for channel {channel}: {str(e)}")
- await sio.emit('error', {'message': 'Failed to load recent messages'}, room=sid)
+ await sio.emit("error", {"message": "Failed to load recent messages"}, room=sid)
# Update user list for this channel
await update_user_list(channel)
# Confirm successful join to the client
- await sio.emit('join_success', room=sid)
+ await sio.emit("join_success", room=sid)
+
@sio.event
async def create_channel(sid, data):
"""Handle channel creation"""
- channel = data.get('channel', '').strip()
- password = data.get('password', '').strip() or None
+ channel = data.get("channel", "").strip()
+ password = data.get("password", "").strip() or None
- if not channel or not re.match(r'^[a-zA-Z0-9_-]{1,30}$', channel):
- await sio.emit('error', {'message': 'Invalid channel name'}, room=sid)
+ if not channel or not re.match(r"^[a-zA-Z0-9_-]{1,30}$", channel):
+ await sio.emit("error", {"message": "Invalid channel name"}, room=sid)
return
if channel in connected_users:
- await sio.emit('error', {'message': 'Channel already exists'}, room=sid)
+ await sio.emit("error", {"message": "Channel already exists"}, room=sid)
return
# Create new channel
connected_users[channel] = {}
channel_passwords[channel] = password
- await sio.emit('channel_created', {'channel': channel}, room=sid)
+ await sio.emit("channel_created", {"channel": channel}, room=sid)
+
@sio.event
async def message(sid, data):
"""Handle and store messages"""
- channel = data.get('channel', '').strip()
- message = data.get('message', '').strip()
-
+ channel = data.get("channel", "").strip()
+ message = data.get("message", "").strip()
+
if not channel or not message:
return
@@ -144,26 +164,36 @@ async def message(sid, data):
# Store message in KenobiDB
message_doc = {
- 'type': 'message',
- 'channel': channel,
- 'username': username,
- 'message': message,
- 'timestamp': datetime.utcnow().isoformat()
+ "type": "message",
+ "channel": channel,
+ "username": username,
+ "message": message,
+ "timestamp": datetime.utcnow().isoformat(),
}
db.insert(message_doc)
-
+
# Broadcast message to channel
- await sio.emit('message', {
- 'username': username,
- 'message': message,
- 'timestamp': message_doc['timestamp']
- }, room=channel)
+ await sio.emit(
+ "message",
+ {
+ "username": username,
+ "message": message,
+ "timestamp": message_doc["timestamp"],
+ },
+ room=channel,
+ )
+
async def update_user_list(channel):
"""Broadcast updated user list to channel"""
if channel in connected_users:
- print(f"Broadcasting user list for channel {channel}: {list(connected_users[channel].values())}")
- await sio.emit('user_list', list(connected_users[channel].values()), room=channel)
+ print(
+ f"Broadcasting user list for channel {channel}: {list(connected_users[channel].values())}"
+ )
+ await sio.emit(
+ "user_list", list(connected_users[channel].values()), room=channel
+ )
+
# Attach Socket.IO to the ASGI app
asgi_app = MyApp()
@@ -171,4 +201,5 @@ app = socketio.ASGIApp(sio, asgi_app)
# Ensure database is closed properly on shutdown
import atexit
+
atexit.register(db.close)
diff --git a/examples/socketio/webtrc/basic/app.py b/examples/socketio/webtrc/basic/app.py
index fafaa2d..cb48c44 100644
--- a/examples/socketio/webtrc/basic/app.py
+++ b/examples/socketio/webtrc/basic/app.py
@@ -10,16 +10,17 @@ active_users = set()
# Map session IDs to usernames for cleanup on disconnect
sid_to_username = {}
+
# 2) Create a MicroPie server class with routes
class MyApp(App):
async def index(self):
# A simple response for the root path
- return 'Use /stream/<room name here> or /watch/<room name here>'
+ return "Use /stream/<room name here> or /watch/<room name here>"
async def stream(self, username: str):
# Check if the username is already actively streaming
if username in active_users:
- return 403, {'error': f'Username {username} is already actively streaming'}
+ return 403, {"error": f"Username {username} is already actively streaming"}
# Mark the username active, render the streamer template
active_users.add(username)
return await self._render_template("stream.html", username=username)
@@ -28,14 +29,17 @@ class MyApp(App):
# Render the watcher template (no need to mark as active here since it's handled in join_room)
return await self._render_template("watch.html", username=username)
+
#
# ------------------- Socket.IO Events for Signaling --------------------
#
+
@sio.event
async def connect(sid, environ):
print(f"[connect] Client connected: {sid}")
+
@sio.event
async def disconnect(sid):
print(f"[disconnect] Client disconnected: {sid}")
@@ -45,6 +49,7 @@ async def disconnect(sid):
active_users.discard(username)
print(f"[disconnect] Removed {username} from active_users")
+
@sio.on("join_room")
async def join_room(sid, data):
"""Each client (streamer or watcher) joins a room named after <username>."""
@@ -55,6 +60,7 @@ async def join_room(sid, data):
await sio.enter_room(sid, username)
print(f"[join_room] {sid} joined room '{username}'")
+
@sio.on("new_watcher")
async def new_watcher(sid, data):
"""
@@ -67,10 +73,13 @@ async def new_watcher(sid, data):
print(f"[new_watcher] {watcher_sid} => watch {username}")
if username in active_users:
# Notify others in the room (specifically the streamer)
- await sio.emit("new_watcher",
- {"watcherSid": watcher_sid},
- room=username,
- skip_sid=watcher_sid)
+ await sio.emit(
+ "new_watcher",
+ {"watcherSid": watcher_sid},
+ room=username,
+ skip_sid=watcher_sid,
+ )
+
@sio.on("offer")
async def handle_offer(sid, data):
@@ -86,13 +95,12 @@ async def handle_offer(sid, data):
print(f"[offer] From streamer {sid} to watcher {watcher_sid}, room={username}")
# Send the offer ONLY to watcherSid (not the whole room)
- await sio.emit("offer",
- {
- "offer": offer_sdp,
- "offerType": offer_type,
- "streamerSid": sid
- },
- to=watcher_sid)
+ await sio.emit(
+ "offer",
+ {"offer": offer_sdp, "offerType": offer_type, "streamerSid": sid},
+ to=watcher_sid,
+ )
+
@sio.on("answer")
async def handle_answer(sid, data):
@@ -106,13 +114,12 @@ async def handle_answer(sid, data):
print(f"[answer] From watcher {sid} to streamer {streamer_sid}")
- await sio.emit("answer",
- {
- "answer": answer_sdp,
- "answerType": answer_type,
- "watcherSid": sid
- },
- to=streamer_sid)
+ await sio.emit(
+ "answer",
+ {"answer": answer_sdp, "answerType": answer_type, "watcherSid": sid},
+ to=streamer_sid,
+ )
+
@sio.on("ice-candidate")
async def handle_ice_candidate(sid, data):
@@ -128,14 +135,17 @@ async def handle_ice_candidate(sid, data):
print(f"[ice-candidate] {sid} => {target_sid}")
if target_sid:
- await sio.emit("ice-candidate",
- {
- "candidate": candidate,
- "sdpMid": sdp_mid,
- "sdpMLineIndex": sdp_mline_index,
- "senderSid": sid
- },
- to=target_sid)
+ await sio.emit(
+ "ice-candidate",
+ {
+ "candidate": candidate,
+ "sdpMid": sdp_mid,
+ "sdpMLineIndex": sdp_mline_index,
+ "senderSid": sid,
+ },
+ to=target_sid,
+ )
+
asgi_app = MyApp()
app = socketio.ASGIApp(sio, asgi_app)
diff --git a/examples/socketio/webtrc/heroku_ready/app.py b/examples/socketio/webtrc/heroku_ready/app.py
index 5108683..4ead9e4 100644
--- a/examples/socketio/webtrc/heroku_ready/app.py
+++ b/examples/socketio/webtrc/heroku_ready/app.py
@@ -7,7 +7,7 @@ from mongokv import Mkv
# ---------------- Config ----------------
ALLOWED_ORIGINS = [
"http://localhost:8000",
- "http://127.0.0.1:8000", # add your domain here
+ "http://127.0.0.1:8000", # add your domain here
]
MONGO_URI = "mongodb://localhost:27017"
@@ -39,12 +39,15 @@ class MyApp(App):
def _k_sid(username: str) -> str:
return f"streamer_sid:{username}"
+
def _k_seen(username: str) -> str:
return f"streamer_seen:{username}"
+
def _k_user_by_sid(sid: str) -> str:
return f"streamer_user_by_sid:{sid}"
+
def _k_token(username: str) -> str:
return f"streamer_token:{username}"
@@ -84,7 +87,9 @@ async def streamer_still_owner(username: str, sid: str) -> bool:
return current == sid
-async def claim_stream_username(username: str, sid: str, stream_token: str | None) -> tuple[bool, str | None]:
+async def claim_stream_username(
+ username: str, sid: str, stream_token: str | None
+) -> tuple[bool, str | None]:
"""
Returns (ok, reason).
- If a fresh streamer exists and token doesn't match -> deny "taken"
@@ -153,7 +158,9 @@ async def join_room(sid, data):
if role == "streamer":
ok, reason = await claim_stream_username(username, sid, stream_token)
if not ok:
- await sio.emit("stream_denied", {"username": username, "reason": reason}, to=sid)
+ await sio.emit(
+ "stream_denied", {"username": username, "reason": reason}, to=sid
+ )
await sio.leave_room(sid, username)
await sio.disconnect(sid)
print(f"[join_room] DENIED streamer '{username}' to {sid} reason={reason}")
@@ -200,7 +207,9 @@ async def handle_offer(sid, data):
# Only current owner can send offers for this username
if not username or not await streamer_still_owner(username, sid):
- await sio.emit("stream_denied", {"username": username, "reason": "not_owner"}, to=sid)
+ await sio.emit(
+ "stream_denied", {"username": username, "reason": "not_owner"}, to=sid
+ )
return
if watcher_sid:
@@ -237,7 +246,9 @@ async def handle_ice_candidate(sid, data):
role = (data or {}).get("role")
if role == "streamer" and username:
if not await streamer_still_owner(username, sid):
- await sio.emit("stream_denied", {"username": username, "reason": "not_owner"}, to=sid)
+ await sio.emit(
+ "stream_denied", {"username": username, "reason": "not_owner"}, to=sid
+ )
return
if target_sid:
@@ -255,4 +266,3 @@ async def handle_ice_candidate(sid, data):
asgi_app = MyApp()
app = socketio.ASGIApp(sio, other_asgi_app=asgi_app)
-
diff --git a/examples/static_content/basic.py b/examples/static_content/basic.py
index e578c47..68d0441 100644
--- a/examples/static_content/basic.py
+++ b/examples/static_content/basic.py
@@ -3,8 +3,8 @@ import os
import aiofiles
import mimetypes
-class Root(App):
+class Root(App):
async def static(self, path):
# Normalize the file path to prevent directory traversal
file_path = os.path.normpath(os.path.join("static", path))
@@ -28,4 +28,5 @@ class Root(App):
return 200, stream_file(), [("Content-Type", content_type)]
return 404, "Not Found", []
+
app = Root()
diff --git a/examples/static_content/servestatic.py b/examples/static_content/servestatic.py
index 0600329..6806330 100644
--- a/examples/static_content/servestatic.py
+++ b/examples/static_content/servestatic.py
@@ -1,10 +1,12 @@
from servestatic import ServeStaticASGI
from micropie import App
+
class Root(App):
async def index(self):
return "Hello, World!"
+
# Create the application
application = Root()
diff --git a/examples/streaming/text.py b/examples/streaming/text.py
index 0c2d970..4ffe264 100644
--- a/examples/streaming/text.py
+++ b/examples/streaming/text.py
@@ -2,8 +2,8 @@ import time
import asyncio
from micropie import App
-class Root(App):
+class Root(App):
def index(self):
# Normal, immediate response (non-streaming)
return "Hello from index!"
@@ -14,8 +14,8 @@ class Root(App):
for i in range(1, 6):
yield f"Chunk {i} "
await asyncio.sleep(1)
- return generator()
+ return generator()
app = Root()
diff --git a/examples/streaming/video.py b/examples/streaming/video.py
index 10134aa..44d6b34 100644
--- a/examples/streaming/video.py
+++ b/examples/streaming/video.py
@@ -3,9 +3,10 @@ from micropie import App
VIDEO_PATH = "video.mp4"
+
class Root(App):
def index(self):
- return '''
+ return """
<html>
<body>
<center>
@@ -16,15 +17,15 @@ class Root(App):
</center>
</body>
</html>
- '''
+ """
async def stream(self):
# Access the request headers using the self.request property
headers = {
- k.decode('latin-1').lower(): v.decode('latin-1')
- for k, v in self.request.scope.get('headers', [])
+ k.decode("latin-1").lower(): v.decode("latin-1")
+ for k, v in self.request.scope.get("headers", [])
}
- range_header = headers.get('range')
+ range_header = headers.get("range")
file_size = os.path.getsize(VIDEO_PATH)
# Decide on start/end
@@ -72,4 +73,5 @@ class Root(App):
return (status_code, file_chunk_generator(start, end), extra_headers)
+
app = Root()
diff --git a/examples/twutr/app.py b/examples/twutr/app.py
index cc12d40..31e2974 100644
--- a/examples/twutr/app.py
+++ b/examples/twutr/app.py
@@ -14,13 +14,13 @@ from micropie import App, SessionBackend
import motor.motor_asyncio
from motor.motor_asyncio import AsyncIOMotorCollection
import bcrypt
-from bson import Binary # the bson package that is included with pymongo
+from bson import Binary # the bson package that is included with pymongo
# ------------------------------------------------------------------------------
# MongoDB Configuration (using Motor)
# ------------------------------------------------------------------------------
-MONGO_URI = ("YOUR URI HERE OR IMPLEMENT ENV VARS")
+MONGO_URI = "YOUR URI HERE OR IMPLEMENT ENV VARS"
client = motor.motor_asyncio.AsyncIOMotorClient(MONGO_URI)
db = client["twutr"]
user_collection = db["users"]
@@ -40,7 +40,7 @@ def hash_password(password: str) -> bytes:
The hashed password.
"""
salt = bcrypt.gensalt()
- return bcrypt.hashpw(password.encode('utf-8'), salt)
+ return bcrypt.hashpw(password.encode("utf-8"), salt)
def check_password(password: str, hashed: Binary) -> bool:
@@ -56,7 +56,7 @@ def check_password(password: str, hashed: Binary) -> bool:
"""
if isinstance(hashed, Binary):
hashed = bytes(hashed) # Convert Binary to bytes
- return bcrypt.checkpw(password.encode('utf-8'), hashed)
+ return bcrypt.checkpw(password.encode("utf-8"), hashed)
# ------------------------------------------------------------------------------
@@ -68,11 +68,12 @@ class MotorSessionBackend(SessionBackend):
Stores session data with an expiration timestamp.
"""
- def __init__(self, mongo_uri: str, db_name: str, collection_name: str = "sessions") -> None:
+ def __init__(
+ self, mongo_uri: str, db_name: str, collection_name: str = "sessions"
+ ) -> None:
self.db = client[db_name]
self.collection = self.db[collection_name]
-
async def load(self, session_id: str) -> Dict[str, Any]:
"""
Load session data from MongoDB. If expired, delete it and return an empty dict.
@@ -85,7 +86,6 @@ class MotorSessionBackend(SessionBackend):
return {}
return doc.get("data", {})
-
async def save(self, session_id: str, data: Dict[str, Any], timeout: int) -> None:
"""
Save session data into MongoDB with a specific timeout.
@@ -94,7 +94,7 @@ class MotorSessionBackend(SessionBackend):
await self.collection.update_one(
{"_id": session_id},
{"$set": {"data": data, "expires_at": expires_at}},
- upsert=True
+ upsert=True,
)
@@ -105,7 +105,7 @@ async def get_user_data(username: str) -> Optional[Dict[str, Any]]:
"""
Retrieve user data by username.
"""
- found_doc = await user_collection.find_one({'username': username})
+ found_doc = await user_collection.find_one({"username": username})
return found_doc if found_doc else None
@@ -113,34 +113,39 @@ async def save_user_data(username: str, data: Dict[str, Any]) -> None:
"""
Save updated user data back to the database with upsert.
"""
- await user_collection.update_one({'username': username}, {'$set': data}, upsert=True)
+ await user_collection.update_one(
+ {"username": username}, {"$set": data}, upsert=True
+ )
-def sort_messages_by_timestamp(messages: List[Tuple[str, str, str]],
- timestamp_index: int) -> List[Tuple[str, str, str]]:
+def sort_messages_by_timestamp(
+ messages: List[Tuple[str, str, str]], timestamp_index: int
+) -> List[Tuple[str, str, str]]:
"""
Sort a list of messages based on a timestamp field.
"""
return sorted(
messages,
- key=lambda x: datetime.strptime(x[timestamp_index], '%m/%d/%Y %I:%M %p'),
- reverse=True
+ key=lambda x: datetime.strptime(x[timestamp_index], "%m/%d/%Y %I:%M %p"),
+ reverse=True,
)
-async def get_all_messages_for_user_and_following(user_id: str) -> List[Tuple[str, str, str]]:
+async def get_all_messages_for_user_and_following(
+ user_id: str,
+) -> List[Tuple[str, str, str]]:
"""
Retrieve messages for the given user and the users they follow.
"""
all_messages: List[Tuple[str, str, str]] = []
user_data = await get_user_data(user_id)
if user_data:
- for message in user_data.get('messages', []):
+ for message in user_data.get("messages", []):
all_messages.append((user_id, message[0], message[1]))
- for following in user_data.get('following', []):
+ for following in user_data.get("following", []):
followed_user_data = await get_user_data(following)
if followed_user_data:
- for message in followed_user_data.get('messages', []):
+ for message in followed_user_data.get("messages", []):
all_messages.append((following, message[0], message[1]))
return all_messages
@@ -152,13 +157,15 @@ async def get_all_messages_from_all_users() -> List[Tuple[str, str, str]]:
all_messages: List[Tuple[str, str, str]] = []
cursor = user_collection.find({})
async for user_data in cursor:
- if 'messages' in user_data:
- for message in user_data['messages']:
- all_messages.append((user_data['username'], message[0], message[1]))
+ if "messages" in user_data:
+ for message in user_data["messages"]:
+ all_messages.append((user_data["username"], message[0], message[1]))
return all_messages
-async def update_follow_relationship(current_user: str, target_username: str, follow: bool = True) -> None:
+async def update_follow_relationship(
+ current_user: str, target_username: str, follow: bool = True
+) -> None:
"""
Update the follow/unfollow relationship between the current user and target user.
"""
@@ -166,37 +173,43 @@ async def update_follow_relationship(current_user: str, target_username: str, fo
target_user_data = await get_user_data(target_username)
if current_user_data and target_user_data:
if follow:
- if target_username not in current_user_data.get('following', []):
- current_user_data.setdefault('following', []).append(target_username)
- if current_user not in target_user_data.get('followers', []):
- target_user_data.setdefault('followers', []).append(current_user)
+ if target_username not in current_user_data.get("following", []):
+ current_user_data.setdefault("following", []).append(target_username)
+ if current_user not in target_user_data.get("followers", []):
+ target_user_data.setdefault("followers", []).append(current_user)
else:
- if target_username in current_user_data.get('following', []):
- current_user_data['following'].remove(target_username)
- if current_user in target_user_data.get('followers', []):
- target_user_data['followers'].remove(current_user)
+ if target_username in current_user_data.get("following", []):
+ current_user_data["following"].remove(target_username)
+ if current_user in target_user_data.get("followers", []):
+ target_user_data["followers"].remove(current_user)
await save_user_data(current_user, current_user_data)
await save_user_data(target_username, target_user_data)
-async def get_most_recent_messages(user_collection: AsyncIOMotorCollection, limit: int = 200) -> List[Tuple[str, str, str]]:
+async def get_most_recent_messages(
+ user_collection: AsyncIOMotorCollection, limit: int = 200
+) -> List[Tuple[str, str, str]]:
"""
Retrieve the most recent messages using an aggregation pipeline.
"""
pipeline = [
{"$match": {"messages": {"$exists": True, "$ne": []}}},
{"$unwind": "$messages"},
- {"$project": {
- "username": 1,
- "message_text": {"$arrayElemAt": ["$messages", 0]},
- "message_timestamp": {"$arrayElemAt": ["$messages", 1]}
- }},
+ {
+ "$project": {
+ "username": 1,
+ "message_text": {"$arrayElemAt": ["$messages", 0]},
+ "message_timestamp": {"$arrayElemAt": ["$messages", 1]},
+ }
+ },
{"$sort": {"message_timestamp": -1}},
- {"$limit": limit}
+ {"$limit": limit},
]
results = []
async for doc in user_collection.aggregate(pipeline):
- results.append((doc.get("username"), doc.get("message_text"), doc.get("message_timestamp")))
+ results.append(
+ (doc.get("username"), doc.get("message_text"), doc.get("message_timestamp"))
+ )
return results
@@ -204,8 +217,8 @@ def convert_custom_syntax(text: str) -> str:
"""
Convert custom user syntax into clickable HTML links.
"""
- link_pattern = r'@((?:https?:\/\/)?[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}(?:\/\S*)?)'
- internal_pattern = r'@(/[\w\-/]+)'
+ link_pattern = r"@((?:https?:\/\/)?[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}(?:\/\S*)?)"
+ internal_pattern = r"@(/[\w\-/]+)"
# Escape the text to prevent any dangerous HTML content
escaped_text = escape(text)
@@ -213,11 +226,10 @@ def convert_custom_syntax(text: str) -> str:
# Process the custom syntax for links
def replace_link(match: re.Match) -> str:
url = match.group(1)
- if not url.startswith('http'):
- url = f'http://{url}'
+ if not url.startswith("http"):
+ url = f"http://{url}"
return f'<a href="{url}" target="_blank">{match.group(1)}</a>'
-
def replace_internal(match: re.Match) -> str:
path = match.group(1)
return f'<a href="{path}">{path}</a>'
@@ -242,13 +254,14 @@ class Twutr(App):
"""
Shows the user's timeline, combining their messages and those from followed users.
"""
- if not self.request.session.get('logged_in'):
- return self._redirect('/public')
- user_id = self.request.session.get('user_id')
+ if not self.request.session.get("logged_in"):
+ return self._redirect("/public")
+ user_id = self.request.session.get("user_id")
messages = await get_all_messages_for_user_and_following(user_id)
messages = sort_messages_by_timestamp(messages, timestamp_index=2)
- return await self._render_template('timeline.html', messages=messages, session=self.request.session)
-
+ return await self._render_template(
+ "timeline.html", messages=messages, session=self.request.session
+ )
async def public(self) -> Any:
"""
@@ -256,15 +269,16 @@ class Twutr(App):
"""
messages = await get_most_recent_messages(user_collection, limit=200)
messages = sort_messages_by_timestamp(messages, timestamp_index=2)
- return await self._render_template('public.html', messages=messages, session=self.request.session)
-
+ return await self._render_template(
+ "public.html", messages=messages, session=self.request.session
+ )
async def user(self, username: str) -> Any:
"""
Displays a specific user's messages and profile information.
"""
- logged_in = self.request.session.get('logged_in')
- current_user = self.request.session.get('user_id')
+ logged_in = self.request.session.get("logged_in")
+ current_user = self.request.session.get("user_id")
username = escape(username)
if not logged_in or not current_user:
@@ -273,49 +287,46 @@ class Twutr(App):
following = None # Viewing own profile.
else:
current_user_data = await get_user_data(current_user)
- following = username in current_user_data.get('following', [])
+ following = username in current_user_data.get("following", [])
user_data = await get_user_data(username)
if user_data:
- messages = user_data.get('messages', [])
+ messages = user_data.get("messages", [])
messages = sort_messages_by_timestamp(messages, timestamp_index=1)
- followers = user_data.get('followers', [])
- following_count = len(user_data.get('following', []))
+ followers = user_data.get("followers", [])
+ following_count = len(user_data.get("following", []))
return await self._render_template(
- 'user.html',
+ "user.html",
messages=messages,
username=username,
session=self.request.session,
following=following,
followers=followers,
- following_count=following_count
+ following_count=following_count,
)
return "User not found", 404
-
async def follow(self, username: str) -> Any:
"""
Allows the current user to follow another user.
"""
- if not self.request.session.get('logged_in'):
- return self._redirect('/login')
+ if not self.request.session.get("logged_in"):
+ return self._redirect("/login")
username = escape(username)
- current_user = self.request.session.get('user_id')
+ current_user = self.request.session.get("user_id")
if username == current_user:
return "You cannot follow yourself"
await update_follow_relationship(current_user, username, follow=True)
- return self._redirect(f'/user/{username}')
-
+ return self._redirect(f"/user/{username}")
async def unfollow(self, username: str) -> Any:
"""
Allows the current user to unfollow another user.
"""
- if not self.request.session.get('logged_in'):
- return self._redirect('/login')
- current_user = self.request.session.get('user_id')
+ if not self.request.session.get("logged_in"):
+ return self._redirect("/login")
+ current_user = self.request.session.get("user_id")
await update_follow_relationship(current_user, escape(username), follow=False)
- return self._redirect(f'/user/{username}')
-
+ return self._redirect(f"/user/{username}")
async def list_followers(self, username: str) -> Any:
"""
@@ -325,15 +336,14 @@ class Twutr(App):
user_data = await get_user_data(username)
if not user_data:
return "User not found", 404
- followers = user_data.get('followers', [])
+ followers = user_data.get("followers", [])
return await self._render_template(
- 'list_followers.html',
+ "list_followers.html",
username=username,
followers=followers,
- session=self.request.session
+ session=self.request.session,
)
-
async def list_following(self, username: str) -> Any:
"""
Displays the list of users that a specified user is following.
@@ -342,108 +352,112 @@ class Twutr(App):
user_data = await get_user_data(username)
if not user_data:
return "User not found", 404
- following = user_data.get('following', [])
+ following = user_data.get("following", [])
return await self._render_template(
- 'list_following.html',
+ "list_following.html",
username=username,
following=following,
- session=self.request.session
+ session=self.request.session,
)
-
async def add_message(self) -> Any:
"""
Registers a new message for the logged-in user and processes custom syntax.
"""
- if not self.request.session.get('logged_in'):
- return self._redirect('/login')
- if self.request.method == 'POST':
- message = self.request.body_params.get('message', [''])[0]
+ if not self.request.session.get("logged_in"):
+ return self._redirect("/login")
+ if self.request.method == "POST":
+ message = self.request.body_params.get("message", [""])[0]
sanitized_message = convert_custom_syntax(message)
if not sanitized_message.strip():
return await self._render_template(
- 'timeline.html',
+ "timeline.html",
error="Message cannot be empty",
- session=self.request.session
+ session=self.request.session,
)
- time_stamp = datetime.utcnow().strftime('%m/%d/%Y %I:%M %p')
+ time_stamp = datetime.utcnow().strftime("%m/%d/%Y %I:%M %p")
message_tuple = (sanitized_message, time_stamp)
- user_data = await get_user_data(self.request.session.get('user_id'))
- user_data.setdefault('messages', []).append(message_tuple)
- await save_user_data(self.request.session.get('user_id'), user_data)
- return self._redirect('/')
-
+ user_data = await get_user_data(self.request.session.get("user_id"))
+ user_data.setdefault("messages", []).append(message_tuple)
+ await save_user_data(self.request.session.get("user_id"), user_data)
+ return self._redirect("/")
async def login(self) -> Any:
"""
Handles user login, verifying credentials with hashed passwords.
"""
- if self.request.session.get('logged_in'):
- return self._redirect('/')
- if self.request.method == 'POST':
- username = escape(self.request.body_params.get('username', [''])[0].strip())
- password = escape(self.request.body_params.get('password', [''])[0].strip())
+ if self.request.session.get("logged_in"):
+ return self._redirect("/")
+ if self.request.method == "POST":
+ username = escape(self.request.body_params.get("username", [""])[0].strip())
+ password = escape(self.request.body_params.get("password", [""])[0].strip())
if not username or not password:
return await self._render_template(
- 'login.html',
+ "login.html",
error="Fields cannot be empty",
- session=self.request.session
+ session=self.request.session,
)
user = await get_user_data(username)
- stored_password = user.get('password', None)
+ stored_password = user.get("password", None)
- if not user or not stored_password or not check_password(password, stored_password):
+ if (
+ not user
+ or not stored_password
+ or not check_password(password, stored_password)
+ ):
return await self._render_template(
- 'login.html',
+ "login.html",
error="Invalid credentials",
- session=self.request.session
+ session=self.request.session,
)
- self.request.session['user_id'] = username
- self.request.session['logged_in'] = True
- return self._redirect('/')
- return await self._render_template('login.html', session=self.request.session)
-
+ self.request.session["user_id"] = username
+ self.request.session["logged_in"] = True
+ return self._redirect("/")
+ return await self._render_template("login.html", session=self.request.session)
async def register(self) -> Any:
"""
Registers a new user account with password hashing.
"""
- if self.request.session.get('logged_in'):
- return self._redirect('/')
- if self.request.method == 'POST':
- username = escape(self.request.body_params.get('username', [''])[0].strip())
- password = escape(self.request.body_params.get('password', [''])[0].strip())
+ if self.request.session.get("logged_in"):
+ return self._redirect("/")
+ if self.request.method == "POST":
+ username = escape(self.request.body_params.get("username", [""])[0].strip())
+ password = escape(self.request.body_params.get("password", [""])[0].strip())
if not username or not password:
return await self._render_template(
- 'login.html',
+ "login.html",
error="Fields cannot be empty",
- session=self.request.session
+ session=self.request.session,
)
existing_user = await get_user_data(username)
if existing_user:
return await self._render_template(
- 'register.html',
+ "register.html",
session=self.request.session,
- error="Username already taken."
+ error="Username already taken.",
)
- await user_collection.insert_one({
- 'username': username,
- 'password': Binary(hash_password(password)),
- 'messages': [],
- 'followers': [],
- 'following': []
- })
- return self._redirect('/login')
- return await self._render_template('register.html', session=self.request.session)
-
+ await user_collection.insert_one(
+ {
+ "username": username,
+ "password": Binary(hash_password(password)),
+ "messages": [],
+ "followers": [],
+ "following": [],
+ }
+ )
+ return self._redirect("/login")
+ return await self._render_template(
+ "register.html", session=self.request.session
+ )
def logout(self) -> Any:
"""
Logs out the current user.
"""
- if self.request.session.get('logged_in'):
- self.request.session.pop('logged_in', None)
- return self._redirect('/public')
+ if self.request.session.get("logged_in"):
+ self.request.session.pop("logged_in", None)
+ return self._redirect("/public")
# ------------------------------------------------------------------------------
diff --git a/examples/url_shortener/main.py b/examples/url_shortener/main.py
index fff0a99..8ededda 100644
--- a/examples/url_shortener/main.py
+++ b/examples/url_shortener/main.py
@@ -4,30 +4,30 @@ URL Shortener using MicroPie and PyMongo. Live at https://erd.sh/
Copyright 2025 Harrison Erd
-Redistribution and use in source and binary forms, with or without
+Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
-1. Redistributions of source code must retain the above copyright notice,
+1. Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
-2. Redistributions in binary form must reproduce the above copyright notice,
-this list of conditions and the following disclaimer in the documentation
+2. Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
-3. Neither the name of the copyright holder nor the names of its
-contributors may be used to endorse or promote products derived from this
+3. Neither the name of the copyright holder nor the names of its
+contributors may be used to endorse or promote products derived from this
software without specific prior written permission.
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
-IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
-THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
-CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
-OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
-WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
-OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
+IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
@@ -114,7 +114,6 @@ def _parse_hide_stats_on_expire(value) -> bool | None:
class Shorty(App):
-
async def index(self, url_str: str | None = None):
if url_str:
if self.request.method == "POST":
@@ -137,13 +136,15 @@ class Shorty(App):
if not exists:
break
- await urls.insert_one({
- "_id": short_id,
- "url": url_str,
- "clicks": 0,
- "created_at": datetime.utcnow(),
- "last_clicked_at": None,
- })
+ await urls.insert_one(
+ {
+ "_id": short_id,
+ "url": url_str,
+ "clicks": 0,
+ "created_at": datetime.utcnow(),
+ "last_clicked_at": None,
+ }
+ )
return await self._render_template(
"success.html",
@@ -226,7 +227,6 @@ class Shorty(App):
class ApiApp(App):
-
async def index(self):
return await self._render_template("api.html")
@@ -244,11 +244,15 @@ class ApiApp(App):
return 400, {"error": "Invalid URL"}
expires_in = _parse_expires_in(data.get("expires_in"))
- expires_at = (datetime.utcnow() + timedelta(seconds=expires_in)) if expires_in else None
+ expires_at = (
+ (datetime.utcnow() + timedelta(seconds=expires_in)) if expires_in else None
+ )
max_clicks = _parse_max_clicks(data.get("max_clicks"))
- hide_stats_on_expire = _parse_hide_stats_on_expire(data.get("hide_stats_on_expire"))
+ hide_stats_on_expire = _parse_hide_stats_on_expire(
+ data.get("hide_stats_on_expire")
+ )
while True:
short_id = _generate_id()
@@ -256,17 +260,23 @@ class ApiApp(App):
if not exists:
break
- await urls.insert_one({
- "_id": short_id,
- "url": url_str,
- "clicks": 0,
- "created_at": datetime.utcnow(),
- "last_clicked_at": None,
- # API-only controls:
- **({"expires_at": expires_at} if expires_at else {}),
- **({"max_clicks": max_clicks} if max_clicks else {}),
- **({"hide_stats_on_expire": hide_stats_on_expire} if hide_stats_on_expire is not None else {}),
- })
+ await urls.insert_one(
+ {
+ "_id": short_id,
+ "url": url_str,
+ "clicks": 0,
+ "created_at": datetime.utcnow(),
+ "last_clicked_at": None,
+ # API-only controls:
+ **({"expires_at": expires_at} if expires_at else {}),
+ **({"max_clicks": max_clicks} if max_clicks else {}),
+ **(
+ {"hide_stats_on_expire": hide_stats_on_expire}
+ if hide_stats_on_expire is not None
+ else {}
+ ),
+ }
+ )
return {
"status": "success",
@@ -315,10 +325,16 @@ class ApiApp(App):
"long_url": doc.get("url"),
"clicks": int(doc.get("clicks", 0)),
"created_at": created_at.isoformat() + "Z" if created_at else None,
- "last_clicked_at": last_clicked_at.isoformat() + "Z" if last_clicked_at else None,
- "expires_at": expires_at.isoformat() + "Z" if isinstance(expires_at, datetime) else None,
+ "last_clicked_at": last_clicked_at.isoformat() + "Z"
+ if last_clicked_at
+ else None,
+ "expires_at": expires_at.isoformat() + "Z"
+ if isinstance(expires_at, datetime)
+ else None,
"max_clicks": int(max_clicks) if isinstance(max_clicks, int) else None,
- "hide_stats_on_expire": hide_stats_on_expire if isinstance(hide_stats_on_expire, bool) else None,
+ "hide_stats_on_expire": hide_stats_on_expire
+ if isinstance(hide_stats_on_expire, bool)
+ else None,
}
@@ -357,4 +373,3 @@ app.middlewares.append(
subapp=ApiApp(),
)
)
-
diff --git a/examples/url_shortener/middlewares/csrf.py b/examples/url_shortener/middlewares/csrf.py
index 0c17ebb..5cbe29e 100644
--- a/examples/url_shortener/middlewares/csrf.py
+++ b/examples/url_shortener/middlewares/csrf.py
@@ -130,9 +130,15 @@ class CSRFMiddleware(HttpMiddleware):
try:
data = self.serializer.loads(submitted, max_age=self.max_age)
except SignatureExpired:
- return {"status_code": 403, "body": "<h1>Expired CSRF token, please reload the page and try again.</h1>"}
+ return {
+ "status_code": 403,
+ "body": "<h1>Expired CSRF token, please reload the page and try again.</h1>",
+ }
except BadSignature:
- return {"status_code": 403, "body": "<h1>Invalid CSRF token signature, please reload the page and try again.</h1>"}
+ return {
+ "status_code": 403,
+ "body": "<h1>Invalid CSRF token signature, please reload the page and try again.</h1>",
+ }
sid = self._get_session_id(request)
token_sid = data.get("sid")
@@ -157,4 +163,3 @@ class CSRFMiddleware(HttpMiddleware):
if to_emit:
extra_headers.append(("X-CSRF-Token", to_emit))
return None
-
diff --git a/examples/url_shortener/middlewares/rate_limit.py b/examples/url_shortener/middlewares/rate_limit.py
index 4fd0fa6..7eecd0e 100644
--- a/examples/url_shortener/middlewares/rate_limit.py
+++ b/examples/url_shortener/middlewares/rate_limit.py
@@ -12,6 +12,7 @@ from micropie import HttpMiddleware
# IP helpers
# ---------------------------------------------------------------------------
+
def _valid_ip(value: str | None) -> str | None:
"""Parse and normalize an IP string, returning canonical string form or None."""
try:
@@ -72,6 +73,7 @@ def _is_cloudflare_socket_ip(socket_ip: str | None) -> bool:
# Middleware
# ---------------------------------------------------------------------------
+
class MongoRateLimitMiddleware(HttpMiddleware):
"""
Global MongoDB-based rate limiter with Cloudflare anti-spoofing.
@@ -188,7 +190,11 @@ class MongoRateLimitMiddleware(HttpMiddleware):
def _key(self, client_ip: str, request) -> str:
if not self.bucket_by_route:
return client_ip
- method = (getattr(request, "method", None) or (request.scope or {}).get("method") or "GET").upper()
+ method = (
+ getattr(request, "method", None)
+ or (request.scope or {}).get("method")
+ or "GET"
+ ).upper()
path = (request.scope or {}).get("path") or "/"
return f"{client_ip}|{method}|{path}"
@@ -225,11 +231,12 @@ class MongoRateLimitMiddleware(HttpMiddleware):
"violations": {"$ifNull": ["$violations", 0]},
"blocked_until": {"$ifNull": ["$blocked_until", None]},
"permanent_blocked": {"$ifNull": ["$permanent_blocked", False]},
- "permanent_blocked_at": {"$ifNull": ["$permanent_blocked_at", None]},
+ "permanent_blocked_at": {
+ "$ifNull": ["$permanent_blocked_at", None]
+ },
"violation_events": {"$ifNull": ["$violation_events", []]},
}
},
-
# 2) Prune old violation events
{
"$set": {
@@ -242,7 +249,6 @@ class MongoRateLimitMiddleware(HttpMiddleware):
}
}
},
-
# 3) Are we currently blocked?
{
"$set": {
@@ -259,7 +265,6 @@ class MongoRateLimitMiddleware(HttpMiddleware):
}
}
},
-
# 4) Update window/count (only if not blocked)
{
"$set": {
@@ -285,12 +290,17 @@ class MongoRateLimitMiddleware(HttpMiddleware):
"$cond": [
"$_blocked_now",
"$count",
- {"$cond": ["$_window_expired", 1, {"$add": ["$count", 1]}]},
+ {
+ "$cond": [
+ "$_window_expired",
+ 1,
+ {"$add": ["$count", 1]},
+ ]
+ },
]
},
}
},
-
# 5) Over limit?
{
"$set": {
@@ -302,12 +312,15 @@ class MongoRateLimitMiddleware(HttpMiddleware):
}
}
},
-
# 6) Record violation if over limit
{
"$set": {
"violations": {
- "$cond": ["$_over_limit", {"$add": ["$violations", 1]}, "$violations"]
+ "$cond": [
+ "$_over_limit",
+ {"$add": ["$violations", 1]},
+ "$violations",
+ ]
},
"violation_events": {
"$cond": [
@@ -318,7 +331,6 @@ class MongoRateLimitMiddleware(HttpMiddleware):
},
}
},
-
# 7) Temporary block escalation
{
"$set": {
@@ -327,7 +339,12 @@ class MongoRateLimitMiddleware(HttpMiddleware):
{
"$and": [
"$_over_limit",
- {"$gte": ["$violations", self.BLOCK_AFTER_VIOLATIONS]},
+ {
+ "$gte": [
+ "$violations",
+ self.BLOCK_AFTER_VIOLATIONS,
+ ]
+ },
]
},
now + timedelta(seconds=self.BLOCK_FOR_SECONDS),
@@ -336,7 +353,6 @@ class MongoRateLimitMiddleware(HttpMiddleware):
}
}
},
-
# 8) Permanent block escalation
{"$set": {"_events_24h": {"$size": "$violation_events"}}},
{
@@ -346,7 +362,12 @@ class MongoRateLimitMiddleware(HttpMiddleware):
{
"$and": [
"$_over_limit",
- {"$gte": ["$_events_24h", self.PERMA_BLOCK_AFTER]},
+ {
+ "$gte": [
+ "$_events_24h",
+ self.PERMA_BLOCK_AFTER,
+ ]
+ },
]
},
True,
@@ -358,7 +379,12 @@ class MongoRateLimitMiddleware(HttpMiddleware):
{
"$and": [
"$_over_limit",
- {"$gte": ["$_events_24h", self.PERMA_BLOCK_AFTER]},
+ {
+ "$gte": [
+ "$_events_24h",
+ self.PERMA_BLOCK_AFTER,
+ ]
+ },
{"$eq": ["$permanent_blocked_at", None]},
]
},
@@ -368,9 +394,15 @@ class MongoRateLimitMiddleware(HttpMiddleware):
},
}
},
-
# 9) Cleanup temp fields
- {"$unset": ["_blocked_now", "_window_expired", "_over_limit", "_events_24h"]},
+ {
+ "$unset": [
+ "_blocked_now",
+ "_window_expired",
+ "_over_limit",
+ "_events_24h",
+ ]
+ },
],
upsert=True,
return_document=ReturnDocument.AFTER,
@@ -407,4 +439,3 @@ class MongoRateLimitMiddleware(HttpMiddleware):
async def after_request(self, request, status_code, response_body, extra_headers):
return None
-
diff --git a/examples/url_shortener/middlewares/sub_app.py b/examples/url_shortener/middlewares/sub_app.py
index d7c1008..f0c1658 100644
--- a/examples/url_shortener/middlewares/sub_app.py
+++ b/examples/url_shortener/middlewares/sub_app.py
@@ -1,5 +1,6 @@
from micropie import HttpMiddleware
+
class SubAppMiddleware(HttpMiddleware):
def __init__(self, mount_path: str, subapp):
self.mount_path = mount_path.lstrip("/")
@@ -9,10 +10,9 @@ class SubAppMiddleware(HttpMiddleware):
path = request.scope["path"].lstrip("/")
if path.startswith(self.mount_path):
request._subapp = self.subapp
- request._subapp_path = path[len(self.mount_path):].lstrip("/") or "/"
+ request._subapp_path = path[len(self.mount_path) :].lstrip("/") or "/"
request._subapp_mount_path = self.mount_path
return None
async def after_request(self, request, status_code, response_body, extra_headers):
return None
-
diff --git a/examples/url_shortener/sessions/mongo_session.py b/examples/url_shortener/sessions/mongo_session.py
index a97e863..8504461 100644
--- a/examples/url_shortener/sessions/mongo_session.py
+++ b/examples/url_shortener/sessions/mongo_session.py
@@ -91,4 +91,3 @@ class MkvSessionBackend(SessionBackend):
}
await self.store.set(key, payload)
-
diff --git a/examples/websockets/app.py b/examples/websockets/app.py
index 26c038f..6ac5580 100644
--- a/examples/websockets/app.py
+++ b/examples/websockets/app.py
@@ -1,5 +1,6 @@
from micropie import App, ConnectionClosed
+
class MyApp(App):
async def chat(self):
"""HTTP handler for GET /chat"""
@@ -18,4 +19,5 @@ class MyApp(App):
except ConnectionClosed:
break
+
app = MyApp()
diff --git a/micropie.py b/micropie.py
index 3b330d0..52ecba3 100644
--- a/micropie.py
+++ b/micropie.py
@@ -28,12 +28,14 @@ except ImportError:
try:
from jinja2 import Environment, FileSystemLoader, select_autoescape
+
JINJA_INSTALLED = True
except ImportError:
JINJA_INSTALLED = False
try:
from multipart import PushMultipartParser, MultipartSegment
+
MULTIPART_INSTALLED = True
except ImportError:
MULTIPART_INSTALLED = False
@@ -44,6 +46,7 @@ except ImportError:
# -----------------------------
SESSION_TIMEOUT: int = 8 * 3600 # Default 8 hours
+
class SessionBackend(ABC):
@abstractmethod
async def load(self, session_id: str) -> Dict[str, Any]:
@@ -67,6 +70,7 @@ class SessionBackend(ABC):
"""
pass
+
class InMemorySessionBackend(SessionBackend):
def __init__(self):
self.sessions: Dict[str, Dict[str, Any]] = {}
@@ -76,8 +80,7 @@ class InMemorySessionBackend(SessionBackend):
"""Remove expired sessions based on SESSION_TIMEOUT."""
now = time.time()
expired = [
- sid for sid, ts in self.last_access.items()
- if now - ts >= SESSION_TIMEOUT
+ sid for sid, ts in self.last_access.items() if now - ts >= SESSION_TIMEOUT
]
for sid in expired:
self.sessions.pop(sid, None)
@@ -107,8 +110,10 @@ class InMemorySessionBackend(SessionBackend):
# -----------------------------
current_request: contextvars.ContextVar[Any] = contextvars.ContextVar("current_request")
+
class Request:
"""Represents an HTTP request in the MicroPie framework."""
+
def __init__(self, scope: Dict[str, Any]) -> None:
"""
Initialize a new Request instance.
@@ -125,20 +130,29 @@ class Request:
self.session: Dict[str, Any] = scope.get("session", {})
self.files: Dict[str, Any] = scope.get("files", {})
self.headers: Dict[str, str] = {
- k.decode("utf-8", errors="replace").lower(): v.decode("utf-8", errors="replace")
+ k.decode("utf-8", errors="replace").lower(): v.decode(
+ "utf-8", errors="replace"
+ )
for k, v in scope.get("headers", [])
}
self.body_parsed: bool = scope.get("body_parsed", False)
+
class WebSocketRequest(Request):
"""Represents a WebSocket request in the MicroPie framework."""
+
def __init__(self, scope: Dict[str, Any]) -> None:
super().__init__(scope)
+
class WebSocket:
"""Manages WebSocket communication in the MicroPie framework."""
- def __init__(self, receive: Callable[[], Awaitable[Dict[str, Any]]],
- send: Callable[[Dict[str, Any]], Awaitable[None]]) -> None:
+
+ def __init__(
+ self,
+ receive: Callable[[], Awaitable[Dict[str, Any]]],
+ send: Callable[[Dict[str, Any]], Awaitable[None]],
+ ) -> None:
"""
Initialize a WebSocket instance.
@@ -151,7 +165,9 @@ class WebSocket:
self.accepted = False
self.session_id: Optional[str] = None
- async def accept(self, subprotocol: Optional[str] = None, session_id: Optional[str] = None) -> None:
+ async def accept(
+ self, subprotocol: Optional[str] = None, session_id: Optional[str] = None
+ ) -> None:
"""
Accept the WebSocket connection.
@@ -167,13 +183,22 @@ class WebSocket:
raise ValueError(f"Expected websocket.connect, got {message['type']}")
headers = []
if session_id:
- headers.append(("Set-Cookie", f"session_id={session_id}; Path=/; SameSite=Lax; HttpOnly; Secure;"))
+ headers.append(
+ (
+ "Set-Cookie",
+ f"session_id={session_id}; Path=/; SameSite=Lax; HttpOnly; Secure;",
+ )
+ )
self.session_id = session_id
- await self.send({
- "type": "websocket.accept",
- "subprotocol": subprotocol,
- "headers": [(k.encode("latin-1"), v.encode("latin-1")) for k, v in headers]
- })
+ await self.send(
+ {
+ "type": "websocket.accept",
+ "subprotocol": subprotocol,
+ "headers": [
+ (k.encode("latin-1"), v.encode("latin-1")) for k, v in headers
+ ],
+ }
+ )
self.accepted = True
async def receive_text(self) -> str:
@@ -189,7 +214,9 @@ class WebSocket:
"""
message = await self.receive()
if message["type"] == "websocket.receive":
- return message.get("text", message.get("bytes", b"").decode("utf-8", "ignore"))
+ return message.get(
+ "text", message.get("bytes", b"").decode("utf-8", "ignore")
+ )
elif message["type"] == "websocket.disconnect":
raise ConnectionClosed()
raise ValueError(f"Unexpected message type: {message['type']}")
@@ -224,10 +251,7 @@ class WebSocket:
"""
if not self.accepted:
raise RuntimeError("WebSocket connection not accepted")
- await self.send({
- "type": "websocket.send",
- "text": data
- })
+ await self.send({"type": "websocket.send", "text": data})
async def send_bytes(self, data: bytes) -> None:
"""
@@ -241,10 +265,7 @@ class WebSocket:
"""
if not self.accepted:
raise RuntimeError("WebSocket connection not accepted")
- await self.send({
- "type": "websocket.send",
- "bytes": data
- })
+ await self.send({"type": "websocket.send", "bytes": data})
async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
"""
@@ -255,15 +276,15 @@ class WebSocket:
reason: Optional reason for closure.
"""
if self.accepted:
- await self.send({
- "type": "websocket.close",
- "code": code,
- "reason": reason or ""
- })
+ await self.send(
+ {"type": "websocket.close", "code": code, "reason": reason or ""}
+ )
self.accepted = False
+
class ConnectionClosed(Exception):
"""Raised when a WebSocket connection is closed."""
+
pass
@@ -274,14 +295,15 @@ class HttpMiddleware(ABC):
"""
Pluggable middleware class that allows hooking into the HTTP request lifecycle.
"""
+
@abstractmethod
async def before_request(self, request: Request) -> Optional[Dict]:
"""
Called before the HTTP request is processed.
-
+
Args:
request: The Request object.
-
+
Returns:
Optional dictionary with response details (status_code, body, headers) to short-circuit the request,
or None to continue processing.
@@ -294,34 +316,36 @@ class HttpMiddleware(ABC):
request: Request,
status_code: int,
response_body: Any,
- extra_headers: List[Tuple[str, str]]
+ extra_headers: List[Tuple[str, str]],
) -> Optional[Dict]:
"""
Called after the HTTP request is processed, but before the final response is sent.
-
+
Args:
request: The Request object.
status_code: The HTTP status code.
response_body: The response body.
extra_headers: List of header tuples.
-
+
Returns:
Optional dictionary with updated response details (status_code, body, headers), or None to use defaults.
"""
pass
+
class WebSocketMiddleware(ABC):
"""
Pluggable middleware class that allows hooking into the WebSocket request lifecycle.
"""
+
@abstractmethod
async def before_websocket(self, request: WebSocketRequest) -> Optional[Dict]:
"""
Called before the WebSocket handler is invoked.
-
+
Args:
request: The WebSocketRequest object.
-
+
Returns:
Optional dictionary with close details (code, reason) to reject the connection,
or None to continue processing.
@@ -332,7 +356,7 @@ class WebSocketMiddleware(ABC):
async def after_websocket(self, request: WebSocketRequest) -> None:
"""
Called after the WebSocket handler completes.
-
+
Args:
request: The WebSocketRequest object.
"""
@@ -349,16 +373,19 @@ class App:
pluggable HTTP middlewares via the 'middlewares' list, WebSocket middlewares via the 'ws_middlewares' list,
and startup/shutdown handlers via 'startup_handlers' and 'shutdown_handlers'.
"""
+
def __init__(self, session_backend: Optional[SessionBackend] = None) -> None:
if JINJA_INSTALLED:
self.env = Environment(
loader=FileSystemLoader("templates"),
autoescape=select_autoescape(["html", "xml"]),
- enable_async=True
+ enable_async=True,
)
else:
self.env = None
- self.session_backend: SessionBackend = session_backend or InMemorySessionBackend()
+ self.session_backend: SessionBackend = (
+ session_backend or InMemorySessionBackend()
+ )
self.middlewares: List[HttpMiddleware] = []
self.ws_middlewares: List[WebSocketMiddleware] = []
self.startup_handlers: List[Callable[[], Awaitable[None]]] = []
@@ -378,7 +405,7 @@ class App:
self,
scope: Dict[str, Any],
receive: Callable[[], Awaitable[Dict[str, Any]]],
- send: Callable[[Dict[str, Any]], Awaitable[None]]
+ send: Callable[[Dict[str, Any]], Awaitable[None]],
) -> None:
"""
ASGI callable interface for the server.
@@ -400,7 +427,7 @@ class App:
async def _asgi_app_lifespan(
self,
receive: Callable[[], Awaitable[Dict[str, Any]]],
- send: Callable[[Dict[str, Any]], Awaitable[None]]
+ send: Callable[[Dict[str, Any]], Awaitable[None]],
) -> None:
"""
Handle ASGI lifespan events for startup and shutdown.
@@ -436,7 +463,7 @@ class App:
self,
scope: Dict[str, Any],
receive: Callable[[], Awaitable[Dict[str, Any]]],
- send: Callable[[Dict[str, Any]], Awaitable[None]]
+ send: Callable[[Dict[str, Any]], Awaitable[None]],
) -> None:
"""
ASGI application entry point for handling HTTP requests.
@@ -446,7 +473,9 @@ class App:
status_code: int = 200
response_body: Any = ""
extra_headers: List[Tuple[str, str]] = []
- parse_task: Optional[asyncio.Task] = None # background multipart task if started
+ parse_task: Optional[asyncio.Task] = (
+ None # background multipart task if started
+ )
async def _cancel_parse_task():
if parse_task is not None and not parse_task.done():
@@ -456,7 +485,9 @@ class App:
except asyncio.CancelledError:
pass
- async def _early_exit(code: int, body: Any, headers: Optional[List[Tuple[str, str]]] = None):
+ async def _early_exit(
+ code: int, body: Any, headers: Optional[List[Tuple[str, str]]] = None
+ ):
await _cancel_parse_task()
await self._send_response(send, code, body, headers or [])
return
@@ -479,13 +510,22 @@ class App:
try:
# Parse query/cookies/session
- request.query_params = parse_qs(scope.get("query_string", b"").decode("utf-8", "ignore"))
+ request.query_params = parse_qs(
+ scope.get("query_string", b"").decode("utf-8", "ignore")
+ )
cookies = self._parse_cookies(request.headers.get("cookie", ""))
- request.session = scope.get("session", await self.session_backend.load(cookies.get("session_id", "")) or {})
+ request.session = scope.get(
+ "session",
+ await self.session_backend.load(cookies.get("session_id", "")) or {},
+ )
content_type = request.headers.get("content-type", "")
# Body parsing setup
- if request.method in ("POST", "PUT", "PATCH") and not request.body_parsed and not request.body_params:
+ if (
+ request.method in ("POST", "PUT", "PATCH")
+ and not request.body_parsed
+ and not request.body_params
+ ):
if "multipart/form-data" in content_type:
if not MULTIPART_INSTALLED:
print("For multipart form data support install 'multipart'.")
@@ -515,14 +555,18 @@ class App:
if not msg.get("more_body"):
break
except asyncio.TimeoutError:
- await _early_exit(408, "408 Request Timeout: Failed to receive body")
+ await _early_exit(
+ 408, "408 Request Timeout: Failed to receive body"
+ )
return
decoded_body = body_data.decode("utf-8", "ignore")
if "application/json" in content_type:
try:
request.get_json = json.loads(decoded_body)
if isinstance(request.get_json, dict):
- request.body_params = {k: [str(v)] for k, v in request.get_json.items()}
+ request.body_params = {
+ k: [str(v)] for k, v in request.get_json.items()
+ }
except Exception:
await _early_exit(400, "400 Bad Request: Bad JSON")
return
@@ -591,9 +635,14 @@ class App:
if handler == getattr(self, "index", None) and path and path != "index":
sig = inspect.signature(handler)
accepts_params = any(
- param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD,
- inspect.Parameter.VAR_POSITIONAL)
- for param in sig.parameters.values() if param.name != "self"
+ param.kind
+ in (
+ inspect.Parameter.POSITIONAL_ONLY,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ inspect.Parameter.VAR_POSITIONAL,
+ )
+ for param in sig.parameters.values()
+ if param.name != "self"
)
if not accepts_params:
await _early_exit(404, "404 Not Found")
@@ -625,7 +674,10 @@ class App:
elif "multipart/form-data" in content_type:
param_value = await _await_file_param(param.name)
if param_value is None and param.default is param.empty:
- await _early_exit(400, f"400 Bad Request: Missing required parameter '{param.name}'")
+ await _early_exit(
+ 400,
+ f"400 Bad Request: Missing required parameter '{param.name}'",
+ )
return
if param_value is None:
param_value = param.default
@@ -634,14 +686,21 @@ class App:
elif param.default is not param.empty:
param_value = param.default
else:
- await _early_exit(400, f"400 Bad Request: Missing required parameter '{param.name}'")
+ await _early_exit(
+ 400,
+ f"400 Bad Request: Missing required parameter '{param.name}'",
+ )
return
func_args.append(param_value)
# Execute handler
try:
- result = await handler(*func_args) if inspect.iscoroutinefunction(handler) else handler(*func_args)
+ result = (
+ await handler(*func_args)
+ if inspect.iscoroutinefunction(handler)
+ else handler(*func_args)
+ )
except Exception:
traceback.print_exc()
await _early_exit(500, "500 Internal Server Error")
@@ -667,11 +726,13 @@ class App:
# HTTP middlewares
for mw in self.middlewares:
- if result := await mw.after_request(request, status_code, response_body, extra_headers):
+ if result := await mw.after_request(
+ request, status_code, response_body, extra_headers
+ ):
status_code, response_body, extra_headers = (
result.get("status_code", status_code),
result.get("body", response_body),
- result.get("headers", extra_headers)
+ result.get("headers", extra_headers),
)
# Session persistence after middlewares so they can mutate request.session
@@ -681,16 +742,19 @@ class App:
# New or updated session
if not session_id:
session_id = str(uuid.uuid4())
- extra_headers.append((
- "Set-Cookie",
- f"session_id={session_id}; Path=/; SameSite=Lax; HttpOnly; Secure;"
- ))
- await self.session_backend.save(session_id, request.session, SESSION_TIMEOUT)
+ extra_headers.append(
+ (
+ "Set-Cookie",
+ f"session_id={session_id}; Path=/; SameSite=Lax; HttpOnly; Secure;",
+ )
+ )
+ await self.session_backend.save(
+ session_id, request.session, SESSION_TIMEOUT
+ )
elif session_id:
# Empty session and existing cookie -> treat as logout/delete
await self.session_backend.save(session_id, {}, 0)
-
# Handle async generators (e.g., SSE)
if hasattr(response_body, "__aiter__"):
sanitized_headers: List[Tuple[str, str]] = []
@@ -700,12 +764,19 @@ class App:
continue
sanitized_headers.append((k, v))
if not any(h[0].lower() == "content-type" for h in sanitized_headers):
- sanitized_headers.append(("Content-Type", "text/html; charset=utf-8"))
- await send({
- "type": "http.response.start",
- "status": status_code,
- "headers": [(k.encode("latin-1"), v.encode("latin-1")) for k, v in sanitized_headers],
- })
+ sanitized_headers.append(
+ ("Content-Type", "text/html; charset=utf-8")
+ )
+ await send(
+ {
+ "type": "http.response.start",
+ "status": status_code,
+ "headers": [
+ (k.encode("latin-1"), v.encode("latin-1"))
+ for k, v in sanitized_headers
+ ],
+ }
+ )
gen = response_body
@@ -714,16 +785,20 @@ class App:
async for chunk in gen:
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
- await send({
+ await send(
+ {
+ "type": "http.response.body",
+ "body": chunk,
+ "more_body": True,
+ }
+ )
+ await send(
+ {
"type": "http.response.body",
- "body": chunk,
- "more_body": True
- })
- await send({
- "type": "http.response.body",
- "body": b"",
- "more_body": False
- })
+ "body": b"",
+ "more_body": False,
+ }
+ )
except asyncio.CancelledError:
raise
finally:
@@ -734,7 +809,10 @@ class App:
try:
while True:
msg_task = asyncio.create_task(receive())
- done, _ = await asyncio.wait([streaming_task, msg_task], return_when=asyncio.FIRST_COMPLETED)
+ done, _ = await asyncio.wait(
+ [streaming_task, msg_task],
+ return_when=asyncio.FIRST_COMPLETED,
+ )
if streaming_task in done:
break
if msg_task in done:
@@ -755,16 +833,18 @@ class App:
pass
return
else:
- await self._send_response(send, status_code, response_body, extra_headers)
+ await self._send_response(
+ send, status_code, response_body, extra_headers
+ )
finally:
current_request.reset(token)
-
+
async def _asgi_app_websocket(
self,
scope: Dict[str, Any],
receive: Callable[[], Awaitable[Dict[str, Any]]],
- send: Callable[[Dict[str, Any]], Awaitable[None]]
+ send: Callable[[Dict[str, Any]], Awaitable[None]],
) -> None:
"""
ASGI application entry point for handling WebSocket requests.
@@ -778,14 +858,21 @@ class App:
token = current_request.set(request)
try:
# Parse request details (query params, cookies, session)
- request.query_params = parse_qs(scope.get("query_string", b"").decode("utf-8", "ignore"))
+ request.query_params = parse_qs(
+ scope.get("query_string", b"").decode("utf-8", "ignore")
+ )
cookies = self._parse_cookies(request.headers.get("cookie", ""))
- request.session = await self.session_backend.load(cookies.get("session_id", "")) or {}
+ request.session = (
+ await self.session_backend.load(cookies.get("session_id", "")) or {}
+ )
# Run WebSocket middleware before_websocket
for mw in self.ws_middlewares:
if result := await mw.before_websocket(request):
- code, reason = result.get("code", 1008), result.get("reason", "Middleware rejected")
+ code, reason = (
+ result.get("code", 1008),
+ result.get("reason", "Middleware rejected"),
+ )
await self._send_websocket_close(send, code, reason)
return
@@ -794,7 +881,9 @@ class App:
parts: List[str] = path.split("/") if path else []
func_name: str = parts[0] if parts else "ws_index"
if func_name.startswith("_"):
- await self._send_websocket_close(send, 1008, "Private handler not allowed")
+ await self._send_websocket_close(
+ send, 1008, "Private handler not allowed"
+ )
return
# Map WebSocket handler (e.g., /chat -> ws_chat)
@@ -804,7 +893,9 @@ class App:
request.path_params = parts[1:] if len(parts) > 1 else []
handler = getattr(self, handler_name, None)
if not handler:
- await self._send_websocket_close(send, 1008, "No matching WebSocket route")
+ await self._send_websocket_close(
+ send, 1008, "No matching WebSocket route"
+ )
return
# Build function arguments
@@ -830,7 +921,9 @@ class App:
elif param.default is not param.empty:
param_value = param.default
else:
- await self._send_websocket_close(send, 1008, f"Missing required parameter '{param.name}'")
+ await self._send_websocket_close(
+ send, 1008, f"Missing required parameter '{param.name}'"
+ )
return
func_args.append(param_value)
@@ -854,7 +947,9 @@ class App:
# Save / clear session after middlewares
if request.session:
- await self.session_backend.save(ws.session_id, request.session, SESSION_TIMEOUT)
+ await self.session_backend.save(
+ ws.session_id, request.session, SESSION_TIMEOUT
+ )
else:
# Treat empty session as logout/delete
await self.session_backend.save(ws.session_id, {}, 0)
@@ -887,7 +982,7 @@ class App:
boundary: bytes,
request: "Request",
*,
- file_queue_maxsize: int = 2048
+ file_queue_maxsize: int = 2048,
) -> None:
"""
Parse multipart directly from ASGI receive() and populate
@@ -930,10 +1025,13 @@ class App:
if current_filename:
# File field → bounded queue enforces backpressure
- current_queue = asyncio.Queue(maxsize=file_queue_maxsize)
+ current_queue = asyncio.Queue(
+ maxsize=file_queue_maxsize
+ )
request.files[current_field_name] = {
"filename": current_filename,
- "content_type": current_content_type or "application/octet-stream",
+ "content_type": current_content_type
+ or "application/octet-stream",
"content": current_queue,
}
else:
@@ -953,7 +1051,9 @@ class App:
current_queue = None
else:
if form_value and current_field_name:
- request.body_params[current_field_name].append(form_value)
+ request.body_params[current_field_name].append(
+ form_value
+ )
form_value = ""
if not msg.get("more_body"):
@@ -970,7 +1070,7 @@ class App:
send: Callable[[Dict[str, Any]], Awaitable[None]],
status_code: int,
body: Any,
- extra_headers: Optional[List[Tuple[str, str]]] = None
+ extra_headers: Optional[List[Tuple[str, str]]] = None,
) -> None:
"""
Send an HTTP response using the ASGI send callable.
@@ -992,47 +1092,42 @@ class App:
sanitized_headers.append((k, v))
if not any(h[0].lower() == "content-type" for h in sanitized_headers):
sanitized_headers.append(("Content-Type", "text/html; charset=utf-8"))
- await send({
- "type": "http.response.start",
- "status": status_code,
- "headers": [(k.encode("latin-1"), v.encode("latin-1")) for k, v in sanitized_headers],
- })
+ await send(
+ {
+ "type": "http.response.start",
+ "status": status_code,
+ "headers": [
+ (k.encode("latin-1"), v.encode("latin-1"))
+ for k, v in sanitized_headers
+ ],
+ }
+ )
# Handle async generators (non-SSE cases; SSE is handled in _asgi_app_http)
if hasattr(body, "__aiter__"):
async for chunk in body:
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
- await send({
- "type": "http.response.body",
- "body": chunk,
- "more_body": True
- })
+ await send(
+ {"type": "http.response.body", "body": chunk, "more_body": True}
+ )
await send({"type": "http.response.body", "body": b"", "more_body": False})
return
if hasattr(body, "__iter__") and not isinstance(body, (bytes, str)):
for chunk in body:
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
- await send({
- "type": "http.response.body",
- "body": chunk,
- "more_body": True
- })
+ await send(
+ {"type": "http.response.body", "body": chunk, "more_body": True}
+ )
await send({"type": "http.response.body", "body": b"", "more_body": False})
return
- response_body = (body if isinstance(body, bytes)
- else str(body).encode("utf-8"))
- await send({
- "type": "http.response.body",
- "body": response_body,
- "more_body": False
- })
+ response_body = body if isinstance(body, bytes) else str(body).encode("utf-8")
+ await send(
+ {"type": "http.response.body", "body": response_body, "more_body": False}
+ )
async def _send_websocket_close(
- self,
- send: Callable[[Dict[str, Any]], Awaitable[None]],
- code: int,
- reason: str
+ self, send: Callable[[Dict[str, Any]], Awaitable[None]], code: int, reason: str
) -> None:
"""
Send a WebSocket close message.
@@ -1042,11 +1137,7 @@ class App:
code: The closure code.
reason: The reason for closure.
"""
- await send({
- "type": "websocket.close",
- "code": code,
- "reason": reason
- })
+ await send({"type": "websocket.close", "code": code, "reason": reason})
def _encode_redirect_url(self, url: str) -> str:
"""
diff --git a/tests.py b/tests.py
index 1c26030..a55e937 100644
--- a/tests.py
+++ b/tests.py
@@ -3,16 +3,32 @@ import unittest
import uuid
from unittest.mock import AsyncMock, patch
from urllib.parse import parse_qs
-from micropie import App, InMemorySessionBackend, Request, WebSocketRequest, SESSION_TIMEOUT, ConnectionClosed, HttpMiddleware
+from micropie import (
+ App,
+ InMemorySessionBackend,
+ Request,
+ WebSocketRequest,
+ SESSION_TIMEOUT,
+ ConnectionClosed,
+ HttpMiddleware,
+)
+
class MicroPieTestCase(unittest.IsolatedAsyncioTestCase):
"""Base test case for MicroPie tests with common setup."""
-
+
async def asyncSetUp(self):
"""Initialize the App instance for each test."""
self.app = App(session_backend=InMemorySessionBackend())
- def create_mock_scope(self, path="/index", method="GET", headers=None, query_string=b"", scope_type="http"):
+ def create_mock_scope(
+ self,
+ path="/index",
+ method="GET",
+ headers=None,
+ query_string=b"",
+ scope_type="http",
+ ):
"""Create a mock ASGI scope for testing."""
if headers is None:
headers = []
@@ -21,9 +37,10 @@ class MicroPieTestCase(unittest.IsolatedAsyncioTestCase):
"method": method,
"path": path,
"headers": headers,
- "query_string": query_string
+ "query_string": query_string,
}
+
class TestRequest(MicroPieTestCase):
"""Tests for the Request and WebSocketRequest classes."""
@@ -34,13 +51,21 @@ class TestRequest(MicroPieTestCase):
"method": "GET",
"path": "/test",
"headers": [(b"host", b"example.com"), (b"cookie", b"session_id=123")],
- "query_string": b"param1=value1"
+ "query_string": b"param1=value1",
}
request = Request(scope)
- request.query_params = parse_qs(scope.get("query_string", b"").decode("utf-8", "ignore"))
+ request.query_params = parse_qs(
+ scope.get("query_string", b"").decode("utf-8", "ignore")
+ )
self.assertEqual(request.method, "GET", "Request method should be GET")
- self.assertEqual(request.headers["host"], "example.com", "Host header should be set")
- self.assertEqual(request.query_params, {"param1": ["value1"]}, "Query params should be parsed")
+ self.assertEqual(
+ request.headers["host"], "example.com", "Host header should be set"
+ )
+ self.assertEqual(
+ request.query_params,
+ {"param1": ["value1"]},
+ "Query params should be parsed",
+ )
self.assertEqual(request.session, {}, "Session should be empty initially")
async def test_websocket_request_initialization(self):
@@ -49,12 +74,21 @@ class TestRequest(MicroPieTestCase):
"type": "websocket",
"path": "/ws_test",
"headers": [(b"host", b"example.com")],
- "query_string": b"param1=value1"
+ "query_string": b"param1=value1",
}
request = WebSocketRequest(scope)
- request.query_params = parse_qs(scope.get("query_string", b"").decode("utf-8", "ignore"))
- self.assertEqual(request.scope["path"], "/ws_test", "WebSocketRequest path should be set")
- self.assertEqual(request.query_params, {"param1": ["value1"]}, "Query params should be parsed")
+ request.query_params = parse_qs(
+ scope.get("query_string", b"").decode("utf-8", "ignore")
+ )
+ self.assertEqual(
+ request.scope["path"], "/ws_test", "WebSocketRequest path should be set"
+ )
+ self.assertEqual(
+ request.query_params,
+ {"param1": ["value1"]},
+ "Query params should be parsed",
+ )
+
class TestSession(MicroPieTestCase):
"""Tests for session management and cookie parsing."""
@@ -67,7 +101,9 @@ class TestSession(MicroPieTestCase):
await backend.save(session_id, session_data, SESSION_TIMEOUT)
loaded_data = await backend.load(session_id)
- self.assertEqual(loaded_data, session_data, "Loaded session data should match saved data")
+ self.assertEqual(
+ loaded_data, session_data, "Loaded session data should match saved data"
+ )
backend.last_access[session_id] = 0 # Simulate expired session
expired_data = await backend.load(session_id)
@@ -77,15 +113,20 @@ class TestSession(MicroPieTestCase):
"""Test parsing of cookie header."""
cookie_header = "session_id=abc123; theme=dark; user=john"
cookies = self.app._parse_cookies(cookie_header)
- self.assertEqual(cookies, {
- "session_id": "abc123",
- "theme": "dark",
- "user": "john"
- }, "Cookies should be parsed correctly")
- self.assertEqual(self.app._parse_cookies(""), {}, "Empty cookie header should return empty dict")
+ self.assertEqual(
+ cookies,
+ {"session_id": "abc123", "theme": "dark", "user": "john"},
+ "Cookies should be parsed correctly",
+ )
+ self.assertEqual(
+ self.app._parse_cookies(""),
+ {},
+ "Empty cookie header should return empty dict",
+ )
async def test_session_management(self):
"""Test session handling in request processing."""
+
async def set_session(self):
self.request.session["user"] = "test_user"
return 200, "Session set"
@@ -93,7 +134,9 @@ class TestSession(MicroPieTestCase):
setattr(self.app, "set_session", set_session.__get__(self.app, App))
scope = self.create_mock_scope(path="/set_session")
- receive = AsyncMock(return_value={"type": "http.request", "body": b"", "more_body": False})
+ receive = AsyncMock(
+ return_value={"type": "http.request", "body": b"", "more_body": False}
+ )
send = AsyncMock()
await self.app(scope, receive, send)
@@ -101,91 +144,111 @@ class TestSession(MicroPieTestCase):
set_cookie_call = None
for call in send.call_args_list:
args = call[0][0]
- if args["type"] == "http.response.start" and any(h[0] == b"Set-Cookie" for h in args["headers"]):
+ if args["type"] == "http.response.start" and any(
+ h[0] == b"Set-Cookie" for h in args["headers"]
+ ):
set_cookie_call = args
break
self.assertIsNotNone(set_cookie_call, "Set-Cookie header not found")
self.assertTrue(
- any(h[0] == b"Set-Cookie" and b"session_id=" in h[1] for h in set_cookie_call["headers"]),
- "Set-Cookie header with session_id not found"
+ any(
+ h[0] == b"Set-Cookie" and b"session_id=" in h[1]
+ for h in set_cookie_call["headers"]
+ ),
+ "Set-Cookie header with session_id not found",
)
self.assertEqual(set_cookie_call["status"], 200, "Status should be 200")
+
class TestRouting(MicroPieTestCase):
"""Tests for HTTP and WebSocket routing."""
async def test_app_handler(self):
"""Test handling of a simple HTTP request with query parameter."""
+
async def index(self, name="World"):
return 200, f"Hello, {name}!"
setattr(self.app, "index", index.__get__(self.app, App))
scope = self.create_mock_scope(path="/index", query_string=b"name=Test")
- receive = AsyncMock(return_value={"type": "http.request", "body": b"", "more_body": False})
+ receive = AsyncMock(
+ return_value={"type": "http.request", "body": b"", "more_body": False}
+ )
send = AsyncMock()
await self.app(scope, receive, send)
- send.assert_any_call({
- "type": "http.response.start",
- "status": 200,
- "headers": [(b"Content-Type", b"text/html; charset=utf-8")]
- })
- send.assert_any_call({
- "type": "http.response.body",
- "body": b"Hello, Test!",
- "more_body": False
- })
+ send.assert_any_call(
+ {
+ "type": "http.response.start",
+ "status": 200,
+ "headers": [(b"Content-Type", b"text/html; charset=utf-8")],
+ }
+ )
+ send.assert_any_call(
+ {"type": "http.response.body", "body": b"Hello, Test!", "more_body": False}
+ )
async def test_404_response(self):
"""Test 404 response for non-existent route."""
scope = self.create_mock_scope(path="/nonexistent")
- receive = AsyncMock(return_value={"type": "http.request", "body": b"", "more_body": False})
+ receive = AsyncMock(
+ return_value={"type": "http.request", "body": b"", "more_body": False}
+ )
send = AsyncMock()
await self.app(scope, receive, send)
- send.assert_any_call({
- "type": "http.response.start",
- "status": 404,
- "headers": [(b"Content-Type", b"text/html; charset=utf-8")]
- })
- send.assert_any_call({
- "type": "http.response.body",
- "body": b"404 Not Found",
- "more_body": False
- })
+ send.assert_any_call(
+ {
+ "type": "http.response.start",
+ "status": 404,
+ "headers": [(b"Content-Type", b"text/html; charset=utf-8")],
+ }
+ )
+ send.assert_any_call(
+ {"type": "http.response.body", "body": b"404 Not Found", "more_body": False}
+ )
async def test_missing_parameter(self):
"""Test handler with missing required parameter."""
+
async def index(self, required_param):
return "Should not reach here"
setattr(self.app, "index", index.__get__(self.app, App))
scope = self.create_mock_scope(path="/index")
- receive = AsyncMock(return_value={"type": "http.request", "body": b"", "more_body": False})
+ receive = AsyncMock(
+ return_value={"type": "http.request", "body": b"", "more_body": False}
+ )
send = AsyncMock()
await self.app(scope, receive, send)
- send.assert_any_call({
- "type": "http.response.start",
- "status": 400,
- "headers": [(b"Content-Type", b"text/html; charset=utf-8")]
- })
- send.assert_any_call({
- "type": "http.response.body",
- "body": b"400 Bad Request: Missing required parameter 'required_param'",
- "more_body": False
- })
+ send.assert_any_call(
+ {
+ "type": "http.response.start",
+ "status": 400,
+ "headers": [(b"Content-Type", b"text/html; charset=utf-8")],
+ }
+ )
+ send.assert_any_call(
+ {
+ "type": "http.response.body",
+ "body": b"400 Bad Request: Missing required parameter 'required_param'",
+ "more_body": False,
+ }
+ )
+
class TestWebSocket(MicroPieTestCase):
"""Tests for WebSocket handling."""
async def test_websocket_handler(self):
"""Test WebSocket connection and message handling."""
+
async def ws_echo(self, ws):
await ws.accept()
msg = await ws.receive_text()
@@ -195,29 +258,24 @@ class TestWebSocket(MicroPieTestCase):
setattr(self.app, "ws_echo", ws_echo.__get__(self.app, App))
scope = self.create_mock_scope(path="/echo", scope_type="websocket")
- receive = AsyncMock(side_effect=[
- {"type": "websocket.connect"},
- {"type": "websocket.receive", "text": "Hello"},
- {"type": "websocket.disconnect", "code": 1000}
- ])
+ receive = AsyncMock(
+ side_effect=[
+ {"type": "websocket.connect"},
+ {"type": "websocket.receive", "text": "Hello"},
+ {"type": "websocket.disconnect", "code": 1000},
+ ]
+ )
send = AsyncMock()
await self.app(scope, receive, send)
- send.assert_any_call({
- "type": "websocket.accept",
- "subprotocol": None,
- "headers": []
- })
- send.assert_any_call({
- "type": "websocket.send",
- "text": "Echo: Hello"
- })
- send.assert_any_call({
- "type": "websocket.close",
- "code": 1000,
- "reason": "Done"
- })
+ send.assert_any_call(
+ {"type": "websocket.accept", "subprotocol": None, "headers": []}
+ )
+ send.assert_any_call({"type": "websocket.send", "text": "Echo: Hello"})
+ send.assert_any_call(
+ {"type": "websocket.close", "code": 1000, "reason": "Done"}
+ )
async def test_websocket_missing_handler(self):
"""Test WebSocket 1008 response for non-existent route."""
@@ -227,23 +285,34 @@ class TestWebSocket(MicroPieTestCase):
await self.app(scope, receive, send)
- send.assert_any_call({
- "type": "websocket.close",
- "code": 1008,
- "reason": "No matching WebSocket route"
- })
+ send.assert_any_call(
+ {
+ "type": "websocket.close",
+ "code": 1008,
+ "reason": "No matching WebSocket route",
+ }
+ )
+
class TestMiddleware(MicroPieTestCase):
"""Tests for HTTP and WebSocket middleware."""
async def test_http_middleware(self):
"""Test HTTP middleware before and after request."""
+
class TestMiddleware(HttpMiddleware):
async def before_request(self, request):
request.custom_data = "set_by_middleware"
return None
- async def after_request(self, request, status_code, response_body, extra_headers):
- return {"status_code": 201, "body": f"{response_body} + middleware", "headers": extra_headers}
+
+ async def after_request(
+ self, request, status_code, response_body, extra_headers
+ ):
+ return {
+ "status_code": 201,
+ "body": f"{response_body} + middleware",
+ "headers": extra_headers,
+ }
self.app.middlewares.append(TestMiddleware())
@@ -253,27 +322,35 @@ class TestMiddleware(MicroPieTestCase):
setattr(self.app, "index", index.__get__(self.app, App))
scope = self.create_mock_scope(path="/index")
- receive = AsyncMock(return_value={"type": "http.request", "body": b"", "more_body": False})
+ receive = AsyncMock(
+ return_value={"type": "http.request", "body": b"", "more_body": False}
+ )
send = AsyncMock()
await self.app(scope, receive, send)
- send.assert_any_call({
- "type": "http.response.start",
- "status": 201,
- "headers": [(b"Content-Type", b"text/html; charset=utf-8")]
- })
- send.assert_any_call({
- "type": "http.response.body",
- "body": b"Data: set_by_middleware + middleware",
- "more_body": False
- })
+ send.assert_any_call(
+ {
+ "type": "http.response.start",
+ "status": 201,
+ "headers": [(b"Content-Type", b"text/html; charset=utf-8")],
+ }
+ )
+ send.assert_any_call(
+ {
+ "type": "http.response.body",
+ "body": b"Data: set_by_middleware + middleware",
+ "more_body": False,
+ }
+ )
+
class TestResponseHandling(MicroPieTestCase):
"""Tests for response handling and edge cases."""
async def test_json_handling(self):
"""Test JSON request and response handling."""
+
async def json_handler(self):
return self.request.get_json
@@ -282,9 +359,15 @@ class TestResponseHandling(MicroPieTestCase):
scope = self.create_mock_scope(
path="/json_handler",
method="POST",
- headers=[(b"content-type", b"application/json")]
+ headers=[(b"content-type", b"application/json")],
+ )
+ receive = AsyncMock(
+ return_value={
+ "type": "http.request",
+ "body": b'{"key": "value"}',
+ "more_body": False,
+ }
)
- receive = AsyncMock(return_value={"type": "http.request", "body": b'{"key": "value"}', "more_body": False})
send = AsyncMock()
with patch("micropie.json") as mock_json:
@@ -295,49 +378,66 @@ class TestResponseHandling(MicroPieTestCase):
mock_json.loads.assert_called_once()
mock_json.dumps.assert_called_once()
- send.assert_any_call({
- "type": "http.response.start",
- "status": 200,
- "headers": [(b"Content-Type", b"application/json")]
- })
- send.assert_any_call({
- "type": "http.response.body",
- "body": b'{"key": "value"}',
- "more_body": False
- })
+ send.assert_any_call(
+ {
+ "type": "http.response.start",
+ "status": 200,
+ "headers": [(b"Content-Type", b"application/json")],
+ }
+ )
+ send.assert_any_call(
+ {
+ "type": "http.response.body",
+ "body": b'{"key": "value"}',
+ "more_body": False,
+ }
+ )
async def test_invalid_json(self):
"""Test handling of invalid JSON in POST request."""
scope = self.create_mock_scope(
path="/index",
method="POST",
- headers=[(b"content-type", b"application/json")]
+ headers=[(b"content-type", b"application/json")],
+ )
+ receive = AsyncMock(
+ return_value={
+ "type": "http.request",
+ "body": b"{invalid}",
+ "more_body": False,
+ }
)
- receive = AsyncMock(return_value={"type": "http.request", "body": b"{invalid}", "more_body": False})
send = AsyncMock()
await self.app(scope, receive, send)
- send.assert_any_call({
- "type": "http.response.start",
- "status": 400,
- "headers": [(b"Content-Type", b"text/html; charset=utf-8")]
- })
- send.assert_any_call({
- "type": "http.response.body",
- "body": b"400 Bad Request: Bad JSON",
- "more_body": False
- })
+ send.assert_any_call(
+ {
+ "type": "http.response.start",
+ "status": 400,
+ "headers": [(b"Content-Type", b"text/html; charset=utf-8")],
+ }
+ )
+ send.assert_any_call(
+ {
+ "type": "http.response.body",
+ "body": b"400 Bad Request: Bad JSON",
+ "more_body": False,
+ }
+ )
async def test_header_injection(self):
"""Test protection against header injection."""
+
async def index(self):
return 200, "Test", [("Bad-Header", "value\r\nInject: malicious")]
setattr(self.app, "index", index.__get__(self.app, App))
scope = self.create_mock_scope(path="/index")
- receive = AsyncMock(return_value={"type": "http.request", "body": b"", "more_body": False})
+ receive = AsyncMock(
+ return_value={"type": "http.request", "body": b"", "more_body": False}
+ )
send = AsyncMock()
await self.app(scope, receive, send)
@@ -353,13 +453,11 @@ class TestResponseHandling(MicroPieTestCase):
self.assertEqual(
start_call["headers"],
[(b"Content-Type", b"text/html; charset=utf-8")],
- "Malicious header should be filtered out"
+ "Malicious header should be filtered out",
+ )
+ send.assert_any_call(
+ {"type": "http.response.body", "body": b"Test", "more_body": False}
)
- send.assert_any_call({
- "type": "http.response.body",
- "body": b"Test",
- "more_body": False
- })
async def test_redirect(self):
"""Test redirect response generation."""
@@ -369,7 +467,10 @@ class TestResponseHandling(MicroPieTestCase):
self.assertEqual(status_code, 302, "Redirect should return 302 status")
self.assertEqual(body, "", "Redirect body should be empty")
self.assertIn(("Location", location), headers, "Location header should be set")
- self.assertIn(("X-Custom", "Value"), headers, "Extra headers should be included")
+ self.assertIn(
+ ("X-Custom", "Value"), headers, "Extra headers should be included"
+ )
+
class TestOptionalDependencies(MicroPieTestCase):
"""Tests for behavior with missing optional dependencies."""
@@ -380,47 +481,64 @@ class TestOptionalDependencies(MicroPieTestCase):
scope = self.create_mock_scope(
path="/index",
method="POST",
- headers=[(b"content-type", b"multipart/form-data; boundary=----boundary")]
+ headers=[
+ (b"content-type", b"multipart/form-data; boundary=----boundary")
+ ],
+ )
+ receive = AsyncMock(
+ return_value={"type": "http.request", "body": b"", "more_body": False}
)
- receive = AsyncMock(return_value={"type": "http.request", "body": b"", "more_body": False})
send = AsyncMock()
await self.app(scope, receive, send)
- send.assert_any_call({
- "type": "http.response.start",
- "status": 500,
- "headers": [(b"Content-Type", b"text/html; charset=utf-8")]
- })
- send.assert_any_call({
- "type": "http.response.body",
- "body": b"500 Internal Server Error",
- "more_body": False
- })
+ send.assert_any_call(
+ {
+ "type": "http.response.start",
+ "status": 500,
+ "headers": [(b"Content-Type", b"text/html; charset=utf-8")],
+ }
+ )
+ send.assert_any_call(
+ {
+ "type": "http.response.body",
+ "body": b"500 Internal Server Error",
+ "more_body": False,
+ }
+ )
async def test_no_jinja_installed(self):
"""Test behavior when Jinja2 is not installed."""
with patch("micropie.JINJA_INSTALLED", False):
+
async def index(self):
return await self._render_template("test.html")
+
setattr(self.app, "index", index.__get__(self.app, App))
scope = self.create_mock_scope(path="/index")
- receive = AsyncMock(return_value={"type": "http.request", "body": b"", "more_body": False})
+ receive = AsyncMock(
+ return_value={"type": "http.request", "body": b"", "more_body": False}
+ )
send = AsyncMock()
await self.app(scope, receive, send)
- send.assert_any_call({
- "type": "http.response.start",
- "status": 200,
- "headers": [(b"Content-Type", b"text/html; charset=utf-8")]
- })
- send.assert_any_call({
- "type": "http.response.body",
- "body": b"500 Internal Server Error: Jinja2 not installed.",
- "more_body": False
- })
+ send.assert_any_call(
+ {
+ "type": "http.response.start",
+ "status": 200,
+ "headers": [(b"Content-Type", b"text/html; charset=utf-8")],
+ }
+ )
+ send.assert_any_call(
+ {
+ "type": "http.response.body",
+ "body": b"500 Internal Server Error: Jinja2 not installed.",
+ "more_body": False,
+ }
+ )
+
if __name__ == "__main__":
unittest.main()