diff --git a/masque/utils/ports2data.py b/masque/utils/ports2data.py index 8dba890..b634543 100644 --- a/masque/utils/ports2data.py +++ b/masque/utils/ports2data.py @@ -6,7 +6,7 @@ and retrieving it (`data_to_ports`). the port locations. This particular approach is just a sensible default; feel free to to write equivalent functions for your own format or alternate storage methods. """ -from typing import Sequence, Optional, Mapping +from typing import Sequence, Optional, Mapping, Union import logging import numpy @@ -16,13 +16,13 @@ from ..label import Label from ..utils import layer_t from ..ports import Port from ..error import PatternError -from ..library import Library, WrapROLibrary +from ..library import Library, WrapROLibrary, Tree logger = logging.getLogger(__name__) -def ports_to_data(pattern: Pattern, layer: layer_t) -> Pattern: +def ports_to_data(pattern_or_tree: Union[Pattern, Tree], layer: layer_t) -> Union[Pattern, Tree]: """ Place a text label at each port location, specifying the port data in the format 'name:ptype angle_deg' @@ -39,6 +39,11 @@ def ports_to_data(pattern: Pattern, layer: layer_t) -> Pattern: Returns: `pattern` """ + pattern: Pattern + if isinstance(pattern_or_tree, Tree): + pattern = pattern_or_tree[pattern_or_tree.top] + else: + pattern = pattern_or_tree for name, port in pattern.ports.items(): if port.rotation is None: angle_deg = numpy.inf @@ -47,7 +52,7 @@ def ports_to_data(pattern: Pattern, layer: layer_t) -> Pattern: pattern.labels += [ Label(string=f'{name}:{port.ptype} {angle_deg:g}', layer=layer, offset=port.offset) ] - return pattern + return pattern_or_tree def data_to_ports(