From aec708db331b4709b750f7758aa5e25350b4dc00 Mon Sep 17 00:00:00 2001 From: jan Date: Sat, 6 Apr 2024 13:01:47 -0700 Subject: [PATCH] add plugged() for manually-aligned ports --- masque/ports.py | 70 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/masque/ports.py b/masque/ports.py index 3e78759..41e6418 100644 --- a/masque/ports.py +++ b/masque/ports.py @@ -4,6 +4,7 @@ import traceback import logging from collections import Counter from abc import ABCMeta, abstractmethod +from itertools import chain import numpy from numpy import pi @@ -246,6 +247,75 @@ class PortList(metaclass=ABCMeta): self.ports.update(new_ports) return self + def plugged( + self, + connections: dict[str, str], + ) -> Self: + """ + Verify that the ports specified by `connections` are coincident and have opposing + rotations, then remove the ports. + + This is used when ports have been "manually" aligned as part of some other routing, + but for whatever reason were not eliminated via `plug()`. + + Args: + connections: Pairs of ports which "plug" each other (same offset, opposing directions) + + Returns: + self + + Raises: + `PortError` if the ports are not properly aligned. + """ + a_names, b_names = list(zip(*connections.items())) + a_ports = [self.ports[pp] for pp in a_names] + b_ports = [self.ports[pp] for pp in b_names] + + a_types = [pp.ptype for pp in a_ports] + b_types = [pp.ptype for pp in b_ports] + type_conflicts = numpy.array([at != bt and at != 'unk' and bt != 'unk' + for at, bt in zip(a_types, b_types)]) + + if type_conflicts.any(): + msg = 'Ports have conflicting types:\n' + for nn, (k, v) in enumerate(connections.items()): + if type_conflicts[nn]: + msg += f'{k} | {a_types[nn]}:{b_types[nn]} | {v}\n' + msg = ''.join(traceback.format_stack()) + '\n' + msg + warnings.warn(msg, stacklevel=2) + + a_offsets = numpy.array([pp.offset for pp in a_ports]) + b_offsets = numpy.array([pp.offset for pp in b_ports]) + a_rotations = numpy.array([pp.rotation if pp.rotation is not None else 0 for pp in a_ports]) + b_rotations = numpy.array([pp.rotation if pp.rotation is not None else 0 for pp in b_ports]) + a_has_rot = numpy.array([pp.rotation is not None for pp in a_ports], dtype=bool) + b_has_rot = numpy.array([pp.rotation is not None for pp in b_ports], dtype=bool) + has_rot = a_has_rot & b_has_rot + + if has_rot.any(): + rotations = numpy.mod(a_rotations - b_rotations - pi, 2 * pi) + rotations[~has_rot] = rotations[has_rot][0] + + if not numpy.allclose(rotations, 0): + rot_deg = numpy.rad2deg(rotations) + msg = 'Port orientations do not match:\n' + for nn, (k, v) in enumerate(connections.items()): + if not numpy.isclose(rot_deg[nn], 0): + msg += f'{k} | {rot_deg[nn]:g} | {v}\n' + raise PortError(msg) + + translations = a_offsets - b_offsets + if not numpy.allclose(translations, 0): + msg = 'Port translations do not match:\n' + for nn, (k, v) in enumerate(connections.items()): + if not numpy.allclose(translations[nn], 0): + msg += f'{k} | {translations[nn]} | {v}\n' + raise PortError(msg) + + for pp in chain(a_names, b_names): + del self.ports[pp] + return self + def check_ports( self, other_names: Iterable[str],