patx/pickledb
update tests with better async coverage
Commit 50d2242 · patx · 2025-02-10T14:45:03-05:00
Comments
No comments yet.
Diff
diff --git a/tests.py b/tests.py
index 0943f39..798e69a 100644
--- a/tests.py
+++ b/tests.py
@@ -2,7 +2,13 @@ import unittest
import os
import time
import signal
-from pickledb import PickleDB # Adjust the import path if needed
+import asyncio
+import aiofiles
+import orjson
+
+# Adjust the import path if needed. For example, if 'pickledb' is your own module,
+# ensure the relative or absolute path matches your project structure.
+from pickledb import PickleDB, AsyncPickleDB
class TestPickleDB(unittest.TestCase):
@@ -114,6 +120,94 @@ class TestPickleDB(unittest.TestCase):
self.assertIsNone(self.db.get("123"))
+class TestAsyncPickleDB(unittest.IsolatedAsyncioTestCase):
+ async def asyncSetUp(self):
+ """Set up an AsyncPickleDB instance with a real file."""
+ self.test_file = "test_async_pickledb.json"
+ if os.path.exists(self.test_file):
+ os.remove(self.test_file)
+ self.db = AsyncPickleDB(self.test_file)
+
+ async def asyncTearDown(self):
+ """Clean up after async tests by removing the test file."""
+ if os.path.exists(self.test_file):
+ os.remove(self.test_file)
+
+ async def test_aset_and_aget(self):
+ """Test setting and retrieving a key-value pair asynchronously."""
+ await self.db.aset("key1", "async_value1")
+ value = await self.db.aget("key1")
+ self.assertEqual(value, "async_value1")
+
+ async def test_aget_nonexistent_key(self):
+ """Test retrieving a key that does not exist asynchronously."""
+ value = await self.db.aget("nonexistent")
+ self.assertIsNone(value)
+
+ async def test_aremove_key(self):
+ """Test removing a key-value pair asynchronously."""
+ await self.db.aset("key1", "to_remove")
+ removed = await self.db.aremove("key1")
+ self.assertTrue(removed)
+ value = await self.db.aget("key1")
+ self.assertIsNone(value)
+
+ async def test_aremove_nonexistent_key(self):
+ """Test removing a key that does not exist asynchronously."""
+ removed = await self.db.aremove("nonexistent")
+ self.assertFalse(removed)
+
+ async def test_apurge(self):
+ """Test purging all keys asynchronously."""
+ await self.db.aset("key1", "val1")
+ await self.db.aset("key2", "val2")
+ await self.db.apurge()
+ keys = await self.db.aall()
+ self.assertEqual(keys, [])
+
+ async def test_aall_keys(self):
+ """Test retrieving all keys asynchronously."""
+ await self.db.aset("keyA", "valA")
+ await self.db.aset("keyB", "valB")
+ keys = await self.db.aall()
+ self.assertListEqual(sorted(keys), ["keyA", "keyB"])
+
+ async def test_asave_and_reload(self):
+ """
+ Test dumping (asave) the async database to disk and reloading it
+ by creating a new AsyncPickleDB instance.
+ """
+ await self.db.aset("async_key", "async_val")
+ await self.db.asave()
+
+ # Create a new AsyncPickleDB instance to verify persistence,
+ # then use its inherited synchronous `get` method.
+ new_db = AsyncPickleDB(self.test_file)
+ self.assertEqual(new_db.get("async_key"), "async_val")
+
+ async def test_aset_non_string_key(self):
+ """Test setting a non-string key asynchronously."""
+ await self.db.aset(123, "val123")
+ value = await self.db.aget("123")
+ self.assertEqual(value, "val123")
+
+ async def test_concurrent_access(self):
+ """Test concurrent async set operations under the lock."""
+ async def set_values(db, start, end):
+ for i in range(start, end):
+ await db.aset(f"key{i}", f"value{i}")
+
+ # Run two coroutines concurrently
+ await asyncio.gather(
+ set_values(self.db, 0, 50),
+ set_values(self.db, 50, 100)
+ )
+
+ # Verify data integrity
+ for i in range(100):
+ val = await self.db.aget(f"key{i}")
+ self.assertEqual(val, f"value{i}")
+
+
if __name__ == "__main__":
unittest.main()
-