diff --git a/masque/ports.py b/masque/ports.py index 8d7b6a6..3a695fb 100644 --- a/masque/ports.py +++ b/masque/ports.py @@ -510,10 +510,19 @@ class PortList(metaclass=ABCMeta): if missing_invals: raise PortError(f'`map_in` values not present in other device: {missing_invals}') + map_in_counts = Counter(map_in.values()) + conflicts_in = {kk for kk, vv in map_in_counts.items() if vv > 1} + if conflicts_in: + raise PortError(f'Duplicate values in `map_in`: {conflicts_in}') + missing_outkeys = set(map_out.keys()) - other if missing_outkeys: raise PortError(f'`map_out` keys not present in other device: {missing_outkeys}') + connected_outkeys = set(map_out.keys()) & set(map_in.values()) + if connected_outkeys: + raise PortError(f'`map_out` keys conflict with connected ports: {connected_outkeys}') + orig_remaining = set(self.ports.keys()) - set(map_in.keys()) other_remaining = other - set(map_out.keys()) - set(map_in.values()) mapped_vals = set(map_out.values()) diff --git a/masque/test/test_ports.py b/masque/test/test_ports.py index 0f60809..4e7d097 100644 --- a/masque/test/test_ports.py +++ b/masque/test/test_ports.py @@ -4,6 +4,7 @@ from numpy import pi from ..ports import Port, PortList from ..error import PortError +from ..pattern import Pattern def test_port_init() -> None: @@ -227,3 +228,32 @@ def test_port_list_plugged_mismatch() -> None: pl = MyPorts() with pytest.raises(PortError): pl.plugged({"A": "B"}) + + +def test_port_list_check_ports_duplicate_map_in_values_raise() -> None: + class MyPorts(PortList): + def __init__(self) -> None: + self._ports = {"A": Port((0, 0), 0), "B": Port((0, 0), 0)} + + @property + def ports(self) -> dict[str, Port]: + return self._ports + + @ports.setter + def ports(self, val: dict[str, Port]) -> None: + self._ports = val + + pl = MyPorts() + with pytest.raises(PortError, match="Duplicate values in `map_in`"): + pl.check_ports({"X", "Y"}, map_in={"A": "X", "B": "X"}) + assert set(pl.ports) == {"A", "B"} + + +def test_pattern_plug_rejects_map_out_on_connected_ports_atomically() -> None: + host = Pattern(ports={"A": Port((0, 0), 0)}) + other = Pattern(ports={"X": Port((0, 0), pi), "Y": Port((5, 0), 0)}) + + with pytest.raises(PortError, match="`map_out` keys conflict with connected ports"): + host.plug(other, {"A": "X"}, map_out={"X": "renamed", "Y": "out"}, append=True) + + assert set(host.ports) == {"A"}