From afc49f945d2df8508872f1ccc6def5b95fa449f2 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Wed, 1 Apr 2026 20:14:53 -0700 Subject: [PATCH] [DeferredDict] add setdefault(), pop(), popitem(), copy() --- masque/test/test_utils.py | 48 ++++++++++++++++++++++++++++++++++++ masque/utils/deferreddict.py | 29 ++++++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/masque/test/test_utils.py b/masque/test/test_utils.py index 45e347e..50b700b 100644 --- a/masque/test/test_utils.py +++ b/masque/test/test_utils.py @@ -1,8 +1,11 @@ +from pathlib import Path + import numpy from numpy.testing import assert_equal, assert_allclose from numpy import pi from ..utils import remove_duplicate_vertices, remove_colinear_vertices, poly_contains_points, rotation_matrix_2d, apply_transforms, DeferredDict +from ..file.utils import tmpfile def test_remove_duplicate_vertices() -> None: @@ -104,3 +107,48 @@ def test_deferred_dict_accessors_resolve_values_once() -> None: assert list(deferred.values()) == [7] assert list(deferred.items()) == [("x", 7)] assert calls == 1 + + +def test_deferred_dict_mutating_accessors_preserve_value_semantics() -> None: + calls = 0 + + def make_value() -> int: + nonlocal calls + calls += 1 + return 7 + + deferred = DeferredDict[str, int]() + + assert deferred.setdefault("x", 5) == 5 + assert deferred["x"] == 5 + + assert deferred.setdefault("y", make_value) == 7 + assert deferred["y"] == 7 + assert calls == 1 + + copy_deferred = deferred.copy() + assert isinstance(copy_deferred, DeferredDict) + assert copy_deferred["x"] == 5 + assert copy_deferred["y"] == 7 + assert calls == 1 + + assert deferred.pop("x") == 5 + assert deferred.pop("missing", 9) == 9 + assert deferred.popitem() == ("y", 7) + + +def test_tmpfile_cleans_up_on_exception(tmp_path: Path) -> None: + target = tmp_path / "out.txt" + temp_path = None + + try: + with tmpfile(target) as stream: + temp_path = Path(stream.name) + stream.write(b"hello") + raise RuntimeError("boom") + except RuntimeError: + pass + + assert temp_path is not None + assert not target.exists() + assert not temp_path.exists() diff --git a/masque/utils/deferreddict.py b/masque/utils/deferreddict.py index def9b10..70893c0 100644 --- a/masque/utils/deferreddict.py +++ b/masque/utils/deferreddict.py @@ -5,6 +5,7 @@ from functools import lru_cache Key = TypeVar('Key') Value = TypeVar('Value') +_MISSING = object() class DeferredDict(dict, Generic[Key, Value]): @@ -46,6 +47,15 @@ class DeferredDict(dict, Generic[Key, Value]): return default return self[key] + def setdefault(self, key: Key, default: Value | Callable[[], Value] | None = None) -> Value | None: + if key in self: + return self[key] + if callable(default): + self[key] = default + else: + self.set_const(key, default) + return self[key] + def items(self) -> Iterator[tuple[Key, Value]]: for key in self.keys(): yield key, self[key] @@ -65,6 +75,25 @@ class DeferredDict(dict, Generic[Key, Value]): else: self.set_const(k, v) + def pop(self, key: Key, default: Value | object = _MISSING) -> Value: + if key in self: + value = self[key] + dict.pop(self, key) + return value + if default is _MISSING: + raise KeyError(key) + return default # type: ignore[return-value] + + def popitem(self) -> tuple[Key, Value]: + key, value_func = dict.popitem(self) + return key, value_func() + + def copy(self) -> 'DeferredDict[Key, Value]': + new = DeferredDict[Key, Value]() + for key in self.keys(): + dict.__setitem__(new, key, dict.__getitem__(self, key)) + return new + def __repr__(self) -> str: return ''