[Pattern] improve atomicity of place(), plug(), interface()

This commit is contained in:
Jan Petykiewicz 2026-03-31 23:00:35 -07:00
commit 9767ee4e62
2 changed files with 63 additions and 11 deletions

View file

@ -1383,6 +1383,15 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
if not skip_port_check:
self.check_ports(other.ports.keys(), map_in=None, map_out=port_map)
if not skip_geometry:
if append:
if isinstance(other, Abstract):
raise PatternError('Must provide a full `Pattern` (not an `Abstract`) when appending!')
else:
if isinstance(other, Pattern):
raise PatternError('Must provide an `Abstract` (not a `Pattern`) when creating a reference. '
'Use `append=True` if you intended to append the full geometry.')
ports = {}
for name, port in other.ports.items():
new_name = port_map.get(name, name)
@ -1404,8 +1413,6 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
return self
if append:
if isinstance(other, Abstract):
raise PatternError('Must provide a full `Pattern` (not an `Abstract`) when appending!')
other_copy = other.deepcopy()
other_copy.ports.clear()
if mirrored:
@ -1414,9 +1421,6 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
other_copy.translate_elements(offset)
self.append(other_copy)
else:
if isinstance(other, Pattern):
raise PatternError('Must provide an `Abstract` (not a `Pattern`) when creating a reference. '
'Use `append=True` if you intended to append the full geometry.')
ref = Ref(mirrored=mirrored)
ref.rotate_around(pivot, rotation)
ref.translate(offset)
@ -1554,6 +1558,13 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
map_out = {out_port_name: next(iter(map_in.keys()))}
self.check_ports(other.ports.keys(), map_in, map_out)
if not skip_geometry:
if append:
if isinstance(other, Abstract):
raise PatternError('Must provide a full `Pattern` (not an `Abstract`) when appending!')
elif isinstance(other, Pattern):
raise PatternError('Must provide an `Abstract` (not a `Pattern`) when creating a reference. '
'Use `append=True` if you intended to append the full geometry.')
try:
translation, rotation, pivot = self.find_transform(
other,
@ -1587,10 +1598,6 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
self._log_port_removal(ki)
map_out[vi] = None
if isinstance(other, Pattern) and not (append or skip_geometry):
raise PatternError('Must provide an `Abstract` (not a `Pattern`) when creating a reference. '
'Use `append=True` if you intended to append the full geometry.')
self.place(
other,
offset = translation,
@ -1662,9 +1669,13 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
else:
raise PatternError(f'Unable to get ports from {type(source)}: {source}')
if port_map:
if port_map is not None:
if isinstance(port_map, dict):
missing_inkeys = set(port_map.keys()) - set(orig_ports.keys())
port_targets = list(port_map.values())
duplicate_targets = {vv for vv in port_targets if port_targets.count(vv) > 1}
if duplicate_targets:
raise PortError(f'Duplicate targets in `port_map`: {duplicate_targets}')
mapped_ports = {port_map[k]: v for k, v in orig_ports.items() if k in port_map}
else:
port_set = set(port_map)

View file

@ -5,10 +5,11 @@ from numpy.testing import assert_equal, assert_allclose
from numpy import pi
from ..error import PatternError
from ..abstract import Abstract
from ..pattern import Pattern
from ..shapes import Polygon
from ..ref import Ref
from ..ports import Port
from ..ports import Port, PortError
from ..label import Label
from ..repetition import Grid
@ -134,6 +135,18 @@ def test_pattern_place_requires_abstract_for_reference() -> None:
with pytest.raises(PatternError, match='Must provide an `Abstract`'):
parent.place(child)
assert not parent.ports
def test_pattern_place_append_requires_pattern_atomically() -> None:
parent = Pattern()
child = Abstract("child", {"A": Port((1, 2), 0)})
with pytest.raises(PatternError, match='Must provide a full `Pattern`'):
parent.place(child, append=True)
assert not parent.ports
def test_pattern_interface() -> None:
source = Pattern()
@ -151,6 +164,34 @@ def test_pattern_interface() -> None:
assert iface.ports["out_A"].ptype == "test"
def test_pattern_interface_duplicate_port_map_targets_raise() -> None:
source = Pattern()
source.ports["A"] = Port((10, 20), 0)
source.ports["B"] = Port((30, 40), pi)
with pytest.raises(PortError, match='Duplicate targets in `port_map`'):
Pattern.interface(source, port_map={"A": "X", "B": "X"})
def test_pattern_interface_empty_port_map_copies_no_ports() -> None:
source = Pattern()
source.ports["A"] = Port((10, 20), 0)
source.ports["B"] = Port((30, 40), pi)
assert not Pattern.interface(source, port_map={}).ports
assert not Pattern.interface(source, port_map=[]).ports
def test_pattern_plug_requires_abstract_for_reference_atomically() -> None:
parent = Pattern(ports={"X": Port((0, 0), 0)})
child = Pattern(ports={"A": Port((0, 0), pi)})
with pytest.raises(PatternError, match='Must provide an `Abstract`'):
parent.plug(child, {"X": "A"})
assert set(parent.ports) == {"X"}
def test_pattern_append_port_conflict_is_atomic() -> None:
pat1 = Pattern()
pat1.ports["A"] = Port((0, 0), 0)