diff --git a/masque/ports.py b/masque/ports.py index 36c47c4..688ed7a 100644 --- a/masque/ports.py +++ b/masque/ports.py @@ -1,4 +1,4 @@ -from typing import Iterable, KeysView, ValuesView, overload, Self +from typing import Iterable, KeysView, ValuesView, overload, Self, Mapping import warnings import traceback import logging @@ -303,7 +303,48 @@ class PortList(metaclass=ABCMeta): """ s_ports = self[map_in.keys()] o_ports = other[map_in.values()] + return self.find_ptransform( + s_ports=s_ports, + o_ports=o_ports, + map_in=map_in, + mirrored=mirrored, + set_rotation=set_rotation, + ) + @staticmethod + def find_ptransform( # TODO needs better name + s_ports: Mapping[str, Port], + o_ports: Mapping[str, Port], + map_in: dict[str, str], + *, + mirrored: tuple[bool, bool] = (False, False), + set_rotation: bool | None = None, + ) -> tuple[NDArray[numpy.float64], float, NDArray[numpy.float64]]: + """ + Given two sets of ports (s_ports and o_ports) and a mapping `map_in` + specifying port connections, find the transform which will correctly + align the specified o_ports onto their respective s_ports. + + Args:t + s_ports: A list of stationary ports + o_ports: A list of ports which are to be moved/mirrored. + map_in: dict of `{'s_port': 'o_port'}` mappings, specifying + port connections. + mirrored: Mirrors `o_ports` across the x or y axes prior to + connecting any ports. + set_rotation: If the necessary rotation cannot be determined from + the ports being connected (i.e. all pairs have at least one + port with `rotation=None`), `set_rotation` must be provided + to indicate how much `o_ports` should be rotated. Otherwise, + `set_rotation` must remain `None`. + + Returns: + - The (x, y) translation (performed last) + - The rotation (radians, counterclockwise) + - The (x, y) pivot point for the rotation + + The rotation should be performed before the translation. + """ s_offsets = numpy.array([p.offset for p in s_ports.values()]) o_offsets = numpy.array([p.offset for p in o_ports.values()]) s_types = [p.ptype for p in s_ports.values()]