[DeferredDict] add setdefault(), pop(), popitem(), copy()
This commit is contained in:
parent
ce46cc18dc
commit
afc49f945d
2 changed files with 77 additions and 0 deletions
|
|
@ -1,8 +1,11 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
from numpy.testing import assert_equal, assert_allclose
|
from numpy.testing import assert_equal, assert_allclose
|
||||||
from numpy import pi
|
from numpy import pi
|
||||||
|
|
||||||
from ..utils import remove_duplicate_vertices, remove_colinear_vertices, poly_contains_points, rotation_matrix_2d, apply_transforms, DeferredDict
|
from ..utils import remove_duplicate_vertices, remove_colinear_vertices, poly_contains_points, rotation_matrix_2d, apply_transforms, DeferredDict
|
||||||
|
from ..file.utils import tmpfile
|
||||||
|
|
||||||
|
|
||||||
def test_remove_duplicate_vertices() -> None:
|
def test_remove_duplicate_vertices() -> None:
|
||||||
|
|
@ -104,3 +107,48 @@ def test_deferred_dict_accessors_resolve_values_once() -> None:
|
||||||
assert list(deferred.values()) == [7]
|
assert list(deferred.values()) == [7]
|
||||||
assert list(deferred.items()) == [("x", 7)]
|
assert list(deferred.items()) == [("x", 7)]
|
||||||
assert calls == 1
|
assert calls == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_deferred_dict_mutating_accessors_preserve_value_semantics() -> None:
|
||||||
|
calls = 0
|
||||||
|
|
||||||
|
def make_value() -> int:
|
||||||
|
nonlocal calls
|
||||||
|
calls += 1
|
||||||
|
return 7
|
||||||
|
|
||||||
|
deferred = DeferredDict[str, int]()
|
||||||
|
|
||||||
|
assert deferred.setdefault("x", 5) == 5
|
||||||
|
assert deferred["x"] == 5
|
||||||
|
|
||||||
|
assert deferred.setdefault("y", make_value) == 7
|
||||||
|
assert deferred["y"] == 7
|
||||||
|
assert calls == 1
|
||||||
|
|
||||||
|
copy_deferred = deferred.copy()
|
||||||
|
assert isinstance(copy_deferred, DeferredDict)
|
||||||
|
assert copy_deferred["x"] == 5
|
||||||
|
assert copy_deferred["y"] == 7
|
||||||
|
assert calls == 1
|
||||||
|
|
||||||
|
assert deferred.pop("x") == 5
|
||||||
|
assert deferred.pop("missing", 9) == 9
|
||||||
|
assert deferred.popitem() == ("y", 7)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tmpfile_cleans_up_on_exception(tmp_path: Path) -> None:
|
||||||
|
target = tmp_path / "out.txt"
|
||||||
|
temp_path = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with tmpfile(target) as stream:
|
||||||
|
temp_path = Path(stream.name)
|
||||||
|
stream.write(b"hello")
|
||||||
|
raise RuntimeError("boom")
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert temp_path is not None
|
||||||
|
assert not target.exists()
|
||||||
|
assert not temp_path.exists()
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from functools import lru_cache
|
||||||
|
|
||||||
Key = TypeVar('Key')
|
Key = TypeVar('Key')
|
||||||
Value = TypeVar('Value')
|
Value = TypeVar('Value')
|
||||||
|
_MISSING = object()
|
||||||
|
|
||||||
|
|
||||||
class DeferredDict(dict, Generic[Key, Value]):
|
class DeferredDict(dict, Generic[Key, Value]):
|
||||||
|
|
@ -46,6 +47,15 @@ class DeferredDict(dict, Generic[Key, Value]):
|
||||||
return default
|
return default
|
||||||
return self[key]
|
return self[key]
|
||||||
|
|
||||||
|
def setdefault(self, key: Key, default: Value | Callable[[], Value] | None = None) -> Value | None:
|
||||||
|
if key in self:
|
||||||
|
return self[key]
|
||||||
|
if callable(default):
|
||||||
|
self[key] = default
|
||||||
|
else:
|
||||||
|
self.set_const(key, default)
|
||||||
|
return self[key]
|
||||||
|
|
||||||
def items(self) -> Iterator[tuple[Key, Value]]:
|
def items(self) -> Iterator[tuple[Key, Value]]:
|
||||||
for key in self.keys():
|
for key in self.keys():
|
||||||
yield key, self[key]
|
yield key, self[key]
|
||||||
|
|
@ -65,6 +75,25 @@ class DeferredDict(dict, Generic[Key, Value]):
|
||||||
else:
|
else:
|
||||||
self.set_const(k, v)
|
self.set_const(k, v)
|
||||||
|
|
||||||
|
def pop(self, key: Key, default: Value | object = _MISSING) -> Value:
|
||||||
|
if key in self:
|
||||||
|
value = self[key]
|
||||||
|
dict.pop(self, key)
|
||||||
|
return value
|
||||||
|
if default is _MISSING:
|
||||||
|
raise KeyError(key)
|
||||||
|
return default # type: ignore[return-value]
|
||||||
|
|
||||||
|
def popitem(self) -> tuple[Key, Value]:
|
||||||
|
key, value_func = dict.popitem(self)
|
||||||
|
return key, value_func()
|
||||||
|
|
||||||
|
def copy(self) -> 'DeferredDict[Key, Value]':
|
||||||
|
new = DeferredDict[Key, Value]()
|
||||||
|
for key in self.keys():
|
||||||
|
dict.__setitem__(new, key, dict.__getitem__(self, key))
|
||||||
|
return new
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return '<DeferredDict with keys ' + repr(set(self.keys())) + '>'
|
return '<DeferredDict with keys ' + repr(set(self.keys())) + '>'
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue