update tests with better async coverage

Commit 50d2242 · patx · 2025-02-10T14:45:03-05:00

Changeset
50d2242ed7873ae87f57ef0f8829379b97998ef8
Parents
061b13868025c4f5b5ab512d67a5963cdde6b0e5

View source at this commit

Comments

No comments yet.

Log in to comment

Diff

diff --git a/tests.py b/tests.py
index 0943f39..798e69a 100644
--- a/tests.py
+++ b/tests.py
@@ -2,7 +2,13 @@ import unittest
 import os
 import time
 import signal
-from pickledb import PickleDB  # Adjust the import path if needed
+import asyncio
+import aiofiles
+import orjson
+
+# Adjust the import path if needed. For example, if 'pickledb' is your own module,
+# ensure the relative or absolute path matches your project structure.
+from pickledb import PickleDB, AsyncPickleDB
 
 
 class TestPickleDB(unittest.TestCase):
@@ -114,6 +120,94 @@ class TestPickleDB(unittest.TestCase):
         self.assertIsNone(self.db.get("123"))
 
 
+class TestAsyncPickleDB(unittest.IsolatedAsyncioTestCase):
+    async def asyncSetUp(self):
+        """Set up an AsyncPickleDB instance with a real file."""
+        self.test_file = "test_async_pickledb.json"
+        if os.path.exists(self.test_file):
+            os.remove(self.test_file)
+        self.db = AsyncPickleDB(self.test_file)
+
+    async def asyncTearDown(self):
+        """Clean up after async tests by removing the test file."""
+        if os.path.exists(self.test_file):
+            os.remove(self.test_file)
+
+    async def test_aset_and_aget(self):
+        """Test setting and retrieving a key-value pair asynchronously."""
+        await self.db.aset("key1", "async_value1")
+        value = await self.db.aget("key1")
+        self.assertEqual(value, "async_value1")
+
+    async def test_aget_nonexistent_key(self):
+        """Test retrieving a key that does not exist asynchronously."""
+        value = await self.db.aget("nonexistent")
+        self.assertIsNone(value)
+
+    async def test_aremove_key(self):
+        """Test removing a key-value pair asynchronously."""
+        await self.db.aset("key1", "to_remove")
+        removed = await self.db.aremove("key1")
+        self.assertTrue(removed)
+        value = await self.db.aget("key1")
+        self.assertIsNone(value)
+
+    async def test_aremove_nonexistent_key(self):
+        """Test removing a key that does not exist asynchronously."""
+        removed = await self.db.aremove("nonexistent")
+        self.assertFalse(removed)
+
+    async def test_apurge(self):
+        """Test purging all keys asynchronously."""
+        await self.db.aset("key1", "val1")
+        await self.db.aset("key2", "val2")
+        await self.db.apurge()
+        keys = await self.db.aall()
+        self.assertEqual(keys, [])
+
+    async def test_aall_keys(self):
+        """Test retrieving all keys asynchronously."""
+        await self.db.aset("keyA", "valA")
+        await self.db.aset("keyB", "valB")
+        keys = await self.db.aall()
+        self.assertListEqual(sorted(keys), ["keyA", "keyB"])
+
+    async def test_asave_and_reload(self):
+        """
+        Test dumping (asave) the async database to disk and reloading it
+        by creating a new AsyncPickleDB instance.
+        """
+        await self.db.aset("async_key", "async_val")
+        await self.db.asave()
+
+        # Create a new AsyncPickleDB instance to verify persistence,
+        # then use its inherited synchronous `get` method.
+        new_db = AsyncPickleDB(self.test_file)
+        self.assertEqual(new_db.get("async_key"), "async_val")
+
+    async def test_aset_non_string_key(self):
+        """Test setting a non-string key asynchronously."""
+        await self.db.aset(123, "val123")
+        value = await self.db.aget("123")
+        self.assertEqual(value, "val123")
+
+    async def test_concurrent_access(self):
+        """Test concurrent async set operations under the lock."""
+        async def set_values(db, start, end):
+            for i in range(start, end):
+                await db.aset(f"key{i}", f"value{i}")
+
+        # Run two coroutines concurrently
+        await asyncio.gather(
+            set_values(self.db, 0, 50),
+            set_values(self.db, 50, 100)
+        )
+
+        # Verify data integrity
+        for i in range(100):
+            val = await self.db.aget(f"key{i}")
+            self.assertEqual(val, f"value{i}")
+
+
 if __name__ == "__main__":
     unittest.main()
-