[DeferredDict] add setdefault(), pop(), popitem(), copy()

This commit is contained in:
Jan Petykiewicz 2026-04-01 20:14:53 -07:00
commit afc49f945d
2 changed files with 77 additions and 0 deletions

View file

@ -1,8 +1,11 @@
from pathlib import Path
import numpy
from numpy.testing import assert_equal, assert_allclose
from numpy import pi
from ..utils import remove_duplicate_vertices, remove_colinear_vertices, poly_contains_points, rotation_matrix_2d, apply_transforms, DeferredDict
from ..file.utils import tmpfile
def test_remove_duplicate_vertices() -> None:
@ -104,3 +107,48 @@ def test_deferred_dict_accessors_resolve_values_once() -> None:
assert list(deferred.values()) == [7]
assert list(deferred.items()) == [("x", 7)]
assert calls == 1
def test_deferred_dict_mutating_accessors_preserve_value_semantics() -> None:
calls = 0
def make_value() -> int:
nonlocal calls
calls += 1
return 7
deferred = DeferredDict[str, int]()
assert deferred.setdefault("x", 5) == 5
assert deferred["x"] == 5
assert deferred.setdefault("y", make_value) == 7
assert deferred["y"] == 7
assert calls == 1
copy_deferred = deferred.copy()
assert isinstance(copy_deferred, DeferredDict)
assert copy_deferred["x"] == 5
assert copy_deferred["y"] == 7
assert calls == 1
assert deferred.pop("x") == 5
assert deferred.pop("missing", 9) == 9
assert deferred.popitem() == ("y", 7)
def test_tmpfile_cleans_up_on_exception(tmp_path: Path) -> None:
target = tmp_path / "out.txt"
temp_path = None
try:
with tmpfile(target) as stream:
temp_path = Path(stream.name)
stream.write(b"hello")
raise RuntimeError("boom")
except RuntimeError:
pass
assert temp_path is not None
assert not target.exists()
assert not temp_path.exists()

View file

@ -5,6 +5,7 @@ from functools import lru_cache
Key = TypeVar('Key')
Value = TypeVar('Value')
_MISSING = object()
class DeferredDict(dict, Generic[Key, Value]):
@ -46,6 +47,15 @@ class DeferredDict(dict, Generic[Key, Value]):
return default
return self[key]
def setdefault(self, key: Key, default: Value | Callable[[], Value] | None = None) -> Value | None:
if key in self:
return self[key]
if callable(default):
self[key] = default
else:
self.set_const(key, default)
return self[key]
def items(self) -> Iterator[tuple[Key, Value]]:
for key in self.keys():
yield key, self[key]
@ -65,6 +75,25 @@ class DeferredDict(dict, Generic[Key, Value]):
else:
self.set_const(k, v)
def pop(self, key: Key, default: Value | object = _MISSING) -> Value:
if key in self:
value = self[key]
dict.pop(self, key)
return value
if default is _MISSING:
raise KeyError(key)
return default # type: ignore[return-value]
def popitem(self) -> tuple[Key, Value]:
key, value_func = dict.popitem(self)
return key, value_func()
def copy(self) -> 'DeferredDict[Key, Value]':
new = DeferredDict[Key, Value]()
for key in self.keys():
dict.__setitem__(new, key, dict.__getitem__(self, key))
return new
def __repr__(self) -> str:
return '<DeferredDict with keys ' + repr(set(self.keys())) + '>'