diff --git a/masque/pattern.py b/masque/pattern.py index 287ca92..b670493 100644 --- a/masque/pattern.py +++ b/masque/pattern.py @@ -1383,6 +1383,15 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): if not skip_port_check: self.check_ports(other.ports.keys(), map_in=None, map_out=port_map) + if not skip_geometry: + if append: + if isinstance(other, Abstract): + raise PatternError('Must provide a full `Pattern` (not an `Abstract`) when appending!') + else: + if isinstance(other, Pattern): + raise PatternError('Must provide an `Abstract` (not a `Pattern`) when creating a reference. ' + 'Use `append=True` if you intended to append the full geometry.') + ports = {} for name, port in other.ports.items(): new_name = port_map.get(name, name) @@ -1404,8 +1413,6 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): return self if append: - if isinstance(other, Abstract): - raise PatternError('Must provide a full `Pattern` (not an `Abstract`) when appending!') other_copy = other.deepcopy() other_copy.ports.clear() if mirrored: @@ -1414,9 +1421,6 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): other_copy.translate_elements(offset) self.append(other_copy) else: - if isinstance(other, Pattern): - raise PatternError('Must provide an `Abstract` (not a `Pattern`) when creating a reference. ' - 'Use `append=True` if you intended to append the full geometry.') ref = Ref(mirrored=mirrored) ref.rotate_around(pivot, rotation) ref.translate(offset) @@ -1554,6 +1558,13 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): map_out = {out_port_name: next(iter(map_in.keys()))} self.check_ports(other.ports.keys(), map_in, map_out) + if not skip_geometry: + if append: + if isinstance(other, Abstract): + raise PatternError('Must provide a full `Pattern` (not an `Abstract`) when appending!') + elif isinstance(other, Pattern): + raise PatternError('Must provide an `Abstract` (not a `Pattern`) when creating a reference. ' + 'Use `append=True` if you intended to append the full geometry.') try: translation, rotation, pivot = self.find_transform( other, @@ -1587,10 +1598,6 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): self._log_port_removal(ki) map_out[vi] = None - if isinstance(other, Pattern) and not (append or skip_geometry): - raise PatternError('Must provide an `Abstract` (not a `Pattern`) when creating a reference. ' - 'Use `append=True` if you intended to append the full geometry.') - self.place( other, offset = translation, @@ -1662,9 +1669,13 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): else: raise PatternError(f'Unable to get ports from {type(source)}: {source}') - if port_map: + if port_map is not None: if isinstance(port_map, dict): missing_inkeys = set(port_map.keys()) - set(orig_ports.keys()) + port_targets = list(port_map.values()) + duplicate_targets = {vv for vv in port_targets if port_targets.count(vv) > 1} + if duplicate_targets: + raise PortError(f'Duplicate targets in `port_map`: {duplicate_targets}') mapped_ports = {port_map[k]: v for k, v in orig_ports.items() if k in port_map} else: port_set = set(port_map) diff --git a/masque/test/test_pattern.py b/masque/test/test_pattern.py index 59cc225..6048bf1 100644 --- a/masque/test/test_pattern.py +++ b/masque/test/test_pattern.py @@ -5,10 +5,11 @@ from numpy.testing import assert_equal, assert_allclose from numpy import pi from ..error import PatternError +from ..abstract import Abstract from ..pattern import Pattern from ..shapes import Polygon from ..ref import Ref -from ..ports import Port +from ..ports import Port, PortError from ..label import Label from ..repetition import Grid @@ -134,6 +135,18 @@ def test_pattern_place_requires_abstract_for_reference() -> None: with pytest.raises(PatternError, match='Must provide an `Abstract`'): parent.place(child) + assert not parent.ports + + +def test_pattern_place_append_requires_pattern_atomically() -> None: + parent = Pattern() + child = Abstract("child", {"A": Port((1, 2), 0)}) + + with pytest.raises(PatternError, match='Must provide a full `Pattern`'): + parent.place(child, append=True) + + assert not parent.ports + def test_pattern_interface() -> None: source = Pattern() @@ -151,6 +164,34 @@ def test_pattern_interface() -> None: assert iface.ports["out_A"].ptype == "test" +def test_pattern_interface_duplicate_port_map_targets_raise() -> None: + source = Pattern() + source.ports["A"] = Port((10, 20), 0) + source.ports["B"] = Port((30, 40), pi) + + with pytest.raises(PortError, match='Duplicate targets in `port_map`'): + Pattern.interface(source, port_map={"A": "X", "B": "X"}) + + +def test_pattern_interface_empty_port_map_copies_no_ports() -> None: + source = Pattern() + source.ports["A"] = Port((10, 20), 0) + source.ports["B"] = Port((30, 40), pi) + + assert not Pattern.interface(source, port_map={}).ports + assert not Pattern.interface(source, port_map=[]).ports + + +def test_pattern_plug_requires_abstract_for_reference_atomically() -> None: + parent = Pattern(ports={"X": Port((0, 0), 0)}) + child = Pattern(ports={"A": Port((0, 0), pi)}) + + with pytest.raises(PatternError, match='Must provide an `Abstract`'): + parent.plug(child, {"X": "A"}) + + assert set(parent.ports) == {"X"} + + def test_pattern_append_port_conflict_is_atomic() -> None: pat1 = Pattern() pat1.ports["A"] = Port((0, 0), 0)