diff --git a/masque/builder/builder.py b/masque/builder/builder.py index b92c9ac..b8b00be 100644 --- a/masque/builder/builder.py +++ b/masque/builder/builder.py @@ -226,6 +226,7 @@ class Builder(PortList): inherit_name: bool = True, set_rotation: bool | None = None, append: bool = False, + ok_connections: Iterable[tuple[str, str]] = (), ) -> Self: """ Wrapper around `Pattern.plug` which allows a string for `other`. @@ -260,6 +261,11 @@ class Builder(PortList): append: If `True`, `other` is appended instead of being referenced. Note that this does not flatten `other`, so its refs will still be refs (now inside `self`). + ok_connections: Set of "allowed" ptype combinations. Identical + ptypes are always allowed to connect, as is `'unk'` with + any other ptypte. Non-allowed ptype connections will emit a + warning. Order is ignored, i.e. `(a, b)` is equivalent to + `(b, a)`. Returns: self @@ -293,6 +299,7 @@ class Builder(PortList): inherit_name=inherit_name, set_rotation=set_rotation, append=append, + ok_connections=ok_connections, ) return self diff --git a/masque/pattern.py b/masque/pattern.py index 1816762..0ae230d 100644 --- a/masque/pattern.py +++ b/masque/pattern.py @@ -1225,6 +1225,7 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): inherit_name: bool = True, set_rotation: bool | None = None, append: bool = False, + ok_connections: Iterable[tuple[str, str]] = (), ) -> Self: """ Instantiate or append a pattern into the current pattern, connecting @@ -1270,6 +1271,11 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): append: If `True`, `other` is appended instead of being referenced. Note that this does not flatten `other`, so its refs will still be refs (now inside `self`). + ok_connections: Set of "allowed" ptype combinations. Identical + ptypes are always allowed to connect, as is `'unk'` with + any other ptypte. Non-allowed ptype connections will emit a + warning. Order is ignored, i.e. `(a, b)` is equivalent to + `(b, a)`. Returns: self @@ -1300,6 +1306,7 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): map_in, mirrored=mirrored, set_rotation=set_rotation, + ok_connections=ok_connections, ) # get rid of plugged ports diff --git a/masque/ports.py b/masque/ports.py index 90ba729..e85b2d9 100644 --- a/masque/ports.py +++ b/masque/ports.py @@ -419,6 +419,7 @@ class PortList(metaclass=ABCMeta): *, mirrored: bool = False, set_rotation: bool | None = None, + ok_connections: Iterable[tuple[str, str]] = (), ) -> tuple[NDArray[numpy.float64], float, NDArray[numpy.float64]]: """ Given a device `other` and a mapping `map_in` specifying port connections, @@ -435,6 +436,11 @@ class PortList(metaclass=ABCMeta): port with `rotation=None`), `set_rotation` must be provided to indicate how much `other` should be rotated. Otherwise, `set_rotation` must remain `None`. + ok_connections: Set of "allowed" ptype combinations. Identical + ptypes are always allowed to connect, as is `'unk'` with + any other ptypte. Non-allowed ptype connections will emit a + warning. Order is ignored, i.e. `(a, b)` is equivalent to + `(b, a)`. Returns: - The (x, y) translation (performed last) @@ -451,6 +457,7 @@ class PortList(metaclass=ABCMeta): map_in=map_in, mirrored=mirrored, set_rotation=set_rotation, + ok_connections=ok_connections, ) @staticmethod @@ -461,6 +468,7 @@ class PortList(metaclass=ABCMeta): *, mirrored: bool = False, set_rotation: bool | None = None, + ok_connections: Iterable[tuple[str, str]] = (), ) -> tuple[NDArray[numpy.float64], float, NDArray[numpy.float64]]: """ Given two sets of ports (s_ports and o_ports) and a mapping `map_in` @@ -479,6 +487,11 @@ class PortList(metaclass=ABCMeta): port with `rotation=None`), `set_rotation` must be provided to indicate how much `o_ports` should be rotated. Otherwise, `set_rotation` must remain `None`. + ok_connections: Set of "allowed" ptype combinations. Identical + ptypes are always allowed to connect, as is `'unk'` with + any other ptypte. Non-allowed ptype connections will emit a + warning. Order is ignored, i.e. `(a, b)` is equivalent to + `(b, a)`. Returns: - The (x, y) translation (performed last) @@ -502,7 +515,8 @@ class PortList(metaclass=ABCMeta): o_offsets[:, 1] *= -1 o_rotations *= -1 - type_conflicts = numpy.array([st != ot and 'unk' not in (st, ot) + ok_pairs = {tuple(sorted(pair)) for pair in ok_connections if pair[0] != pair[1])} + type_conflicts = numpy.array([(st != ot) and ('unk' not in (st, ot)) and (tuple(sorted((st, ot))) not in ok_pairs) for st, ot in zip(s_types, o_types, strict=True)]) if type_conflicts.any(): msg = 'Ports have conflicting types:\n'