diff --git a/masque/pattern.py b/masque/pattern.py index 4c72fdf..6d73c4b 100644 --- a/masque/pattern.py +++ b/masque/pattern.py @@ -1049,6 +1049,8 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): line_color: str = 'k', fill_color: str = 'none', overdraw: bool = False, + filename: str | None = None, + ports: bool = False, ) -> None: """ Draw a picture of the Pattern and wait for the user to inspect it @@ -1063,6 +1065,8 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): line_color: Outlines are drawn with this color (passed to `matplotlib.collections.PolyCollection`) fill_color: Interiors are drawn with this color (passed to `matplotlib.collections.PolyCollection`) overdraw: Whether to create a new figure or draw on a pre-existing one + filename: If provided, save the figure to this file instead of showing it. + ports: If True, annotate the plot with arrows representing the ports """ # TODO: add text labels to visualize() try: @@ -1080,7 +1084,6 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): if not overdraw: figure = pyplot.figure() - pyplot.axis('equal') else: figure = pyplot.gcf() @@ -1088,15 +1091,34 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): polygons = [] for shape in chain.from_iterable(self.shapes.values()): - polygons += [offset + s.offset + s.vertices for s in shape.to_polygons()] + for ss in shape.to_polygons(): + polygons.append(offset + ss.offset + ss.vertices) mpl_poly_collection = matplotlib.collections.PolyCollection( polygons, - facecolors=fill_color, - edgecolors=line_color, + facecolors = fill_color, + edgecolors = line_color, ) axes.add_collection(mpl_poly_collection) - pyplot.axis('equal') + + if ports: + for port_name, port in self.ports.items(): + if port.rotation is not None: + p1 = offset + port.offset + angle = port.rotation + size = 1.0 # Arrow size based on bounds or fixed + dx = size * numpy.cos(angle) + dy = size * numpy.sin(angle) + p2 = p1 + numpy.array([dx, dy]) + + axes.annotate( + port_name, + xy = tuple(p1), + xytext = tuple(p2), + arrowprops = dict(arrowstyle="->", color='g', linewidth=1), + color = 'g', + fontsize = 8, + ) for target, refs in self.refs.items(): if target is None: @@ -1107,17 +1129,24 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): target_pat = library[target] for ref in refs: ref.as_pattern(target_pat).visualize( - library=library, - offset=offset, - overdraw=True, - line_color=line_color, - fill_color=fill_color, + library = library, + offset = offset, + overdraw = True, + line_color = line_color, + fill_color = fill_color, + filename = filename, ) + axes.autoscale_view() + axes.set_aspect('equal') + if not overdraw: - pyplot.xlabel('x') - pyplot.ylabel('y') - pyplot.show() + axes.set_xlabel('x') + axes.set_ylabel('y') + if filename: + figure.savefig(filename) + else: + figure.show() # @overload # def place(