diff --git a/masque/library.py b/masque/library.py index df637af..db9dda5 100644 --- a/masque/library.py +++ b/masque/library.py @@ -538,6 +538,7 @@ class ILibraryView(Mapping[str, 'Pattern'], metaclass=ABCMeta): raise LibraryError('visit_* functions returned a new `Pattern` object' ' but no top-level name was provided in `hierarchy`') + del cast('ILibrary', self)[name] cast('ILibrary', self)[name] = pattern return self diff --git a/masque/test/test_library.py b/masque/test/test_library.py index d035db6..56ee3d7 100644 --- a/masque/test/test_library.py +++ b/masque/test/test_library.py @@ -221,6 +221,54 @@ def test_library_rename() -> None: assert "old" not in lib["parent"].refs +def test_library_dfs_can_replace_existing_patterns() -> None: + lib = Library() + child = Pattern() + lib["child"] = child + top = Pattern() + top.ref("child") + lib["top"] = top + + replacement_top = Pattern(ports={"T": Port((1, 2), 0)}) + replacement_child = Pattern(ports={"C": Port((3, 4), 0)}) + + def visit_after(pattern: Pattern, hierarchy: tuple[str | None, ...], **kwargs) -> Pattern: # noqa: ARG001 + if hierarchy[-1] == "child": + return replacement_child + if hierarchy[-1] == "top": + return replacement_top + return pattern + + lib.dfs(lib["top"], visit_after=visit_after, hierarchy=("top",), transform=True) + + assert lib["top"] is replacement_top + assert lib["child"] is replacement_child + + +def test_lazy_library_dfs_can_replace_existing_patterns() -> None: + lib = LazyLibrary() + lib["child"] = lambda: Pattern() + lib["top"] = lambda: Pattern(refs={"child": []}) + + top = lib["top"] + top.ref("child") + + replacement_top = Pattern(ports={"T": Port((1, 2), 0)}) + replacement_child = Pattern(ports={"C": Port((3, 4), 0)}) + + def visit_after(pattern: Pattern, hierarchy: tuple[str | None, ...], **kwargs) -> Pattern: # noqa: ARG001 + if hierarchy[-1] == "child": + return replacement_child + if hierarchy[-1] == "top": + return replacement_top + return pattern + + lib.dfs(top, visit_after=visit_after, hierarchy=("top",), transform=True) + + assert lib["top"] is replacement_top + assert lib["child"] is replacement_child + + def test_library_add_no_duplicates_respects_mutate_other_false() -> None: src_pat = Pattern(ports={"A": Port((0, 0), 0)}) lib = Library({"a": Pattern()})