patx/mongokv

change get method to use sentinal default instead of None

Commit 96338cb · patx · 2025-12-12T14:39:40-05:00

Changeset
96338cba6f0d959b41e1bcf1d8552b6e64ff2927
Parents
c08fa1beb958a5c9d39fcc4d0159d2a388cf4b9c

View source at this commit

Comments

No comments yet.

Log in to comment

Diff

diff --git a/mkvdb.py b/mkvdb.py
index efd7519..c04b75d 100644
--- a/mkvdb.py
+++ b/mkvdb.py
@@ -8,7 +8,10 @@ import asyncio
 from typing import Any
 
 from pymongo import MongoClient, AsyncMongoClient
-from bson import ObjectId 
+from bson import ObjectId
+
+
+MISSING = object()
 
 
 def in_async() -> bool:
@@ -43,7 +46,7 @@ class Mkv:
             async def _aset() -> str:
                 if key is None:
                     new_id = str(ObjectId())
-                    await self.collection.insert_one({"_id": new_id, 
+                    await self.collection.insert_one({"_id": new_id,
                         "value": value})
                     return new_id
                 key_str = str(key)
@@ -60,19 +63,23 @@ class Mkv:
             {"$set": {"value": value}},upsert=True,)
         return key_str
 
-    def get(self, key: str, default: Any | None = None) -> Any | None:
-        """Get the value for a key. """
+    def get(self, key: str, default: Any = MISSING) -> Any:
+        """Get the value for a key."""
         if in_async():
-            async def _aget() -> Any | None:
+            async def _aget() -> Any:
                 doc = await self.collection.find_one({"_id": str(key)})
                 if doc is None:
+                    if default is MISSING:
+                        raise KeyError(key)
                     return default
-                return doc.get("value", default)
+                return doc.get("value")
             return _aget()
         doc = self._sync_collection.find_one({"_id": str(key)})
         if doc is None:
+            if default is MISSING:
+                raise KeyError(key)
             return default
-        return doc.get("value", default)
+        return doc.get("value")
 
     def remove(self, key: str) -> bool:
         """
diff --git a/test_mkvdb.py b/test_mkvdb.py
index 198441d..0b30cc2 100644
--- a/test_mkvdb.py
+++ b/test_mkvdb.py
@@ -79,8 +79,8 @@ async def test_get_missing_returns_default_async(mkv: Mkv):
 
 @pytest.mark.asyncio
 async def test_get_missing_default_none_async(mkv: Mkv):
-    value = await mkv.get("missing")
-    assert value is None
+    with pytest.raises(KeyError):
+        await mkv.get("missing")
 
 
 @pytest.mark.asyncio
@@ -98,8 +98,8 @@ async def test_remove_existing_key_async(mkv: Mkv):
     assert removed is True
 
     # Confirm it’s gone
-    value = await mkv.get("temp")
-    assert value is None
+    with pytest.raises(KeyError):
+        await mkv.get("temp")
 
 
 @pytest.mark.asyncio
@@ -161,8 +161,29 @@ async def test_close_does_not_throw_async(mkv: Mkv):
     # Don't assert behavior after close (Motor generally allows it but it's not required)
 
 
[email protected]
+async def test_get_missing_default_none_explicit_async(mkv: Mkv):
+    value = await mkv.get("missing", default=None)
+    assert value is None
+
+
[email protected]
+async def test_get_missing_raises_keyerror_async(mkv: Mkv):
+    with pytest.raises(KeyError):
+        await mkv.get("missing")
+        
+        
 # ---------- Sync tests (dualmethod behavior) ----------
 
+def test_sync_get_missing_raises_keyerror(mkv_sync: Mkv):
+    with pytest.raises(KeyError):
+        mkv_sync.get("missing")
+
+def test_sync_get_missing_default_none_explicit(mkv_sync: Mkv):
+    value = mkv_sync.get("missing", default=None)
+    assert value is None
+    
+
 def test_sync_set_and_get(mkv_sync: Mkv):
     mkv_sync.set("foo", "bar")
     value = mkv_sync.get("foo")
@@ -178,13 +199,14 @@ def test_sync_remove_and_purge(mkv_sync: Mkv):
     mkv_sync.set("a", 1)
     mkv_sync.set("b", 2)
 
-    # Remove one key
     removed = mkv_sync.remove("a")
     assert removed is True
-    assert mkv_sync.get("a") is None
+
+    with pytest.raises(KeyError):
+        mkv_sync.get("a")
+
     assert mkv_sync.get("b") == 2
 
-    # Purge everything
     purged = mkv_sync.purge()
     assert purged is True
     assert mkv_sync.all() == []