make AsyncPickleDB fully async with context manager support

Commit f20b0d0 · patx · 2025-08-11T12:49:42-04:00

Changeset
f20b0d0797ab41d45e8c7265366ec301f4cdd318
Parents
46ab99ffad71ea4fd8c29adcb3aea076e24a1865

View source at this commit

make AsyncPickleDB fully async with context manager support

- Added async context manager methods (__aenter__, __aexit__)
  to automatically load on entry and save on exit.
- Converted file loading (_load) into fully async `aload()`
  using aiofiles for non-blocking I/O.
- Kept atomic save behavior with async file writes and os.replace
  offloaded via asyncio.to_thread.
- Ensured all CRUD methods remain protected by asyncio.Lock
  for safe concurrent access.

Comments

No comments yet.

Log in to comment

Diff

diff --git a/pickledb.py b/pickledb.py
index 686262e..ff6daca 100644
--- a/pickledb.py
+++ b/pickledb.py
@@ -30,7 +30,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 """
 import asyncio
 import os
-
 import aiofiles
 import orjson
 
@@ -196,103 +195,74 @@ class PickleDB:
         return list(self.db.keys())
 
 
-class AsyncPickleDB(PickleDB):
+class AsyncPickleDB:
+    """
+    A fully asynchronous orjson-based key-value store.
+    Provides async load, save, and CRUD operations with file locking.
+    """
 
     def __init__(self, location):
-        super().__init__(location)
+        self.location = os.path.expanduser(location)
         self._lock = asyncio.Lock()
+        self.db = {}
 
-    async def aset(self, key, value):
-        """
-        Async version of the set method.
-
-        Args:
-            key (any): The key to set. If the key is not a string, it
-                       will be converted to a string.
-            value (any): The value to associate with the key.
-
-        Behavior:
-            - If the key already exists, its value will be updated.
-            - If the key does not exist, it will be added to the
-              database.
-
-        Returns:
-            bool: True if the operation succeeds.
-        """
-        async with self._lock:
-            self.db[str(key)] = value
-            return True
-
-    async def aget(self, key):
-        """
-        Async version of the get method.
-
-        Args:
-            key (any): The key to retrieve. If the key is not a
-                       string, it will be converted to a string.
+    async def __aenter__(self):
+        await self.aload()
+        return self
 
-        Returns:
-            any: The value associated with the key, or None if the
-            key does not exist.
-        """
-        async with self._lock:
-            return self.db.get(str(key))
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        if exc_type is None:
+            await self.asave()
+        return False  # Do not suppress exceptions
 
-    async def aremove(self, key):
+    async def aload(self):
         """
-        Async version of the remove method.
-
-        Args:
-            key (any): The key to delete. If the key is not a string,
-                       it will be converted to a string.
-
-        Returns:
-            bool: True if the key was deleted, False if the key does
-                  not exist.
+        Load data from the JSON file if it exists.
         """
-        async with self._lock:
-            return self.db.pop(str(key), None) is not None
+        if os.path.exists(self.location) and os.path.getsize(self.location) > 0:
+            try:
+                async with aiofiles.open(self.location, "rb") as f:
+                    content = await f.read()
+                self.db = orjson.loads(content)
+            except Exception as e:
+                raise RuntimeError(f"{e}\nFailed to load database.")
+        else:
+            self.db = {}
 
-    async def asave(self):
+    async def asave(self, option=0):
         """
-        Async version of the save method.
-
-        Behavior:
-            - Writes to a temporary file and replaces the
-              original file only after the write is successful,
-              ensuring data integrity.
-
-        Returns:
-            bool: True if save was successful, False if not.
+        Save the database to file atomically.
         """
         temp_location = f"{self.location}.tmp"
         async with self._lock:
             try:
                 async with aiofiles.open(temp_location, "wb") as temp_file:
-                    await temp_file.write(orjson.dumps(self.db))
+                    await temp_file.write(orjson.dumps(self.db, option=option))
                 await asyncio.to_thread(os.replace, temp_location, self.location)
                 return True
             except Exception as e:
                 print(f"Failed to save database: {e}")
                 return False
 
-    async def aall(self):
-        """
-        Async version of the all method.
+    async def aset(self, key, value):
+        async with self._lock:
+            self.db[str(key)] = value
+            return True
 
-        Returns:
-            list: A list of all keys.
-        """
+    async def aget(self, key):
+        async with self._lock:
+            return self.db.get(str(key))
+
+    async def aremove(self, key):
+        async with self._lock:
+            return self.db.pop(str(key), None) is not None
+
+    async def aall(self):
         async with self._lock:
             return list(self.db.keys())
 
     async def apurge(self):
-        """
-        Async version of the purge method.
-
-        Returns:
-            bool: True if the operation succeeds.
-        """
         async with self._lock:
             self.db.clear()
             return True
+