[DeferredDict] add setdefault(), pop(), popitem(), copy()
This commit is contained in:
parent
ce46cc18dc
commit
afc49f945d
2 changed files with 77 additions and 0 deletions
|
|
@ -1,8 +1,11 @@
|
|||
from pathlib import Path
|
||||
|
||||
import numpy
|
||||
from numpy.testing import assert_equal, assert_allclose
|
||||
from numpy import pi
|
||||
|
||||
from ..utils import remove_duplicate_vertices, remove_colinear_vertices, poly_contains_points, rotation_matrix_2d, apply_transforms, DeferredDict
|
||||
from ..file.utils import tmpfile
|
||||
|
||||
|
||||
def test_remove_duplicate_vertices() -> None:
|
||||
|
|
@ -104,3 +107,48 @@ def test_deferred_dict_accessors_resolve_values_once() -> None:
|
|||
assert list(deferred.values()) == [7]
|
||||
assert list(deferred.items()) == [("x", 7)]
|
||||
assert calls == 1
|
||||
|
||||
|
||||
def test_deferred_dict_mutating_accessors_preserve_value_semantics() -> None:
|
||||
calls = 0
|
||||
|
||||
def make_value() -> int:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return 7
|
||||
|
||||
deferred = DeferredDict[str, int]()
|
||||
|
||||
assert deferred.setdefault("x", 5) == 5
|
||||
assert deferred["x"] == 5
|
||||
|
||||
assert deferred.setdefault("y", make_value) == 7
|
||||
assert deferred["y"] == 7
|
||||
assert calls == 1
|
||||
|
||||
copy_deferred = deferred.copy()
|
||||
assert isinstance(copy_deferred, DeferredDict)
|
||||
assert copy_deferred["x"] == 5
|
||||
assert copy_deferred["y"] == 7
|
||||
assert calls == 1
|
||||
|
||||
assert deferred.pop("x") == 5
|
||||
assert deferred.pop("missing", 9) == 9
|
||||
assert deferred.popitem() == ("y", 7)
|
||||
|
||||
|
||||
def test_tmpfile_cleans_up_on_exception(tmp_path: Path) -> None:
|
||||
target = tmp_path / "out.txt"
|
||||
temp_path = None
|
||||
|
||||
try:
|
||||
with tmpfile(target) as stream:
|
||||
temp_path = Path(stream.name)
|
||||
stream.write(b"hello")
|
||||
raise RuntimeError("boom")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
assert temp_path is not None
|
||||
assert not target.exists()
|
||||
assert not temp_path.exists()
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from functools import lru_cache
|
|||
|
||||
Key = TypeVar('Key')
|
||||
Value = TypeVar('Value')
|
||||
_MISSING = object()
|
||||
|
||||
|
||||
class DeferredDict(dict, Generic[Key, Value]):
|
||||
|
|
@ -46,6 +47,15 @@ class DeferredDict(dict, Generic[Key, Value]):
|
|||
return default
|
||||
return self[key]
|
||||
|
||||
def setdefault(self, key: Key, default: Value | Callable[[], Value] | None = None) -> Value | None:
|
||||
if key in self:
|
||||
return self[key]
|
||||
if callable(default):
|
||||
self[key] = default
|
||||
else:
|
||||
self.set_const(key, default)
|
||||
return self[key]
|
||||
|
||||
def items(self) -> Iterator[tuple[Key, Value]]:
|
||||
for key in self.keys():
|
||||
yield key, self[key]
|
||||
|
|
@ -65,6 +75,25 @@ class DeferredDict(dict, Generic[Key, Value]):
|
|||
else:
|
||||
self.set_const(k, v)
|
||||
|
||||
def pop(self, key: Key, default: Value | object = _MISSING) -> Value:
|
||||
if key in self:
|
||||
value = self[key]
|
||||
dict.pop(self, key)
|
||||
return value
|
||||
if default is _MISSING:
|
||||
raise KeyError(key)
|
||||
return default # type: ignore[return-value]
|
||||
|
||||
def popitem(self) -> tuple[Key, Value]:
|
||||
key, value_func = dict.popitem(self)
|
||||
return key, value_func()
|
||||
|
||||
def copy(self) -> 'DeferredDict[Key, Value]':
|
||||
new = DeferredDict[Key, Value]()
|
||||
for key in self.keys():
|
||||
dict.__setitem__(new, key, dict.__getitem__(self, key))
|
||||
return new
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return '<DeferredDict with keys ' + repr(set(self.keys())) + '>'
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue