From 338c123fb18fc7adccbfa4d758b294aad8adff12 Mon Sep 17 00:00:00 2001 From: jan Date: Sat, 7 Mar 2026 23:57:12 -0800 Subject: [PATCH] [pattern] speed up visualize() --- masque/pattern.py | 178 +++++++++++++++++++++++++--------- masque/test/test_visualize.py | 55 +++++++++++ 2 files changed, 186 insertions(+), 47 deletions(-) create mode 100644 masque/test/test_visualize.py diff --git a/masque/pattern.py b/masque/pattern.py index 12cae7f..9e15910 100644 --- a/masque/pattern.py +++ b/masque/pattern.py @@ -1061,12 +1061,13 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): klayout or a different GDS viewer! Args: - offset: Coordinates to offset by before drawing - line_color: Outlines are drawn with this color (passed to `matplotlib.collections.PolyCollection`) - fill_color: Interiors are drawn with this color (passed to `matplotlib.collections.PolyCollection`) - overdraw: Whether to create a new figure or draw on a pre-existing one + library: Mapping of {name: Pattern} for resolving references. Required if `self.has_refs()`. + offset: Coordinates to offset by before drawing. + line_color: Outlines are drawn with this color. + fill_color: Interiors are drawn with this color. + overdraw: Whether to create a new figure or draw on a pre-existing one. filename: If provided, save the figure to this file instead of showing it. - ports: If True, annotate the plot with arrows representing the ports + ports: If True, annotate the plot with arrows representing the ports. """ # TODO: add text labels to visualize() try: @@ -1080,8 +1081,113 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): if self.has_refs() and library is None: raise PatternError('Must provide a library when visualizing a pattern with refs') - offset = numpy.asarray(offset, dtype=float) + # Cache for {Pattern object ID: List of local polygon vertex arrays} + # Polygons are stored relative to the pattern's origin (offset included) + poly_cache: dict[int, list[NDArray[numpy.float64]]] = {} + def get_local_polys(pat: 'Pattern') -> list[NDArray[numpy.float64]]: + pid = id(pat) + if pid not in poly_cache: + polys = [] + for shape in chain.from_iterable(pat.shapes.values()): + for ss in shape.to_polygons(): + # Shape.to_polygons() returns Polygons with their own offsets and vertices. + # We need to expand any shape-level repetition here. + v_base = ss.vertices + ss.offset + if ss.repetition is not None: + for disp in ss.repetition.displacements: + polys.append(v_base + disp) + else: + polys.append(v_base) + poly_cache[pid] = polys + return poly_cache[pid] + + all_polygons: list[NDArray[numpy.float64]] = [] + port_info: list[tuple[str, NDArray[numpy.float64], float]] = [] + + def collect_polys_recursive( + pat: 'Pattern', + c_offset: NDArray[numpy.float64], + c_rotation: float, + c_mirrored: bool, + c_scale: float, + ) -> None: + # Current transform: T(c_offset) * R(c_rotation) * M(c_mirrored) * S(c_scale) + + # 1. Transform and collect local polygons + local_polys = get_local_polys(pat) + if local_polys: + rot_mat = rotation_matrix_2d(c_rotation) + for v in local_polys: + vt = v * c_scale + if c_mirrored: + vt = vt.copy() + vt[:, 1] *= -1 + vt = (rot_mat @ vt.T).T + c_offset + all_polygons.append(vt) + + # 2. Collect ports if requested + if ports: + for name, p in pat.ports.items(): + pt_v = p.offset * c_scale + if c_mirrored: + pt_v = pt_v.copy() + pt_v[1] *= -1 + pt_v = rotation_matrix_2d(c_rotation) @ pt_v + c_offset + + if p.rotation is not None: + pt_rot = p.rotation + if c_mirrored: + pt_rot = -pt_rot + pt_rot += c_rotation + port_info.append((name, pt_v, pt_rot)) + + # 3. Recurse into refs + for target, refs in pat.refs.items(): + if target is None: + continue + target_pat = library[target] + for ref in refs: + # Ref order of operations: mirror, rotate, scale, translate, repeat + + # Combined scale and mirror + r_scale = c_scale * ref.scale + r_mirrored = c_mirrored ^ ref.mirrored + + # Combined rotation: push c_mirrored and c_rotation through ref.rotation + r_rot_relative = -ref.rotation if c_mirrored else ref.rotation + r_rotation = c_rotation + r_rot_relative + + # Offset composition helper + def get_full_offset(rel_offset: NDArray[numpy.float64]) -> NDArray[numpy.float64]: + o = rel_offset * c_scale + if c_mirrored: + o = o.copy() + o[1] *= -1 + return rotation_matrix_2d(c_rotation) @ o + c_offset + + if ref.repetition is not None: + for disp in ref.repetition.displacements: + collect_polys_recursive( + target_pat, + get_full_offset(ref.offset + disp), + r_rotation, + r_mirrored, + r_scale + ) + else: + collect_polys_recursive( + target_pat, + get_full_offset(ref.offset), + r_rotation, + r_mirrored, + r_scale + ) + + # Start recursive collection + collect_polys_recursive(self, numpy.asarray(offset, dtype=float), 0.0, False, 1.0) + + # Plotting if not overdraw: figure = pyplot.figure() else: @@ -1089,50 +1195,28 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable): axes = figure.gca() - polygons = [] - for shape in chain.from_iterable(self.shapes.values()): - for ss in shape.to_polygons(): - polygons.append(offset + ss.offset + ss.vertices) - - mpl_poly_collection = matplotlib.collections.PolyCollection( - polygons, - facecolors = fill_color, - edgecolors = line_color, - ) - axes.add_collection(mpl_poly_collection) + if all_polygons: + mpl_poly_collection = matplotlib.collections.PolyCollection( + all_polygons, + facecolors = fill_color, + edgecolors = line_color, + ) + axes.add_collection(mpl_poly_collection) if ports: - for port_name, port in self.ports.items(): - if port.rotation is not None: - p1 = offset + port.offset - angle = port.rotation - size = 1.0 # arrow size - p2 = p1 + size * numpy.array([numpy.cos(angle), numpy.sin(angle)]) + for port_name, pt_v, pt_rot in port_info: + p1 = pt_v + angle = pt_rot + size = 1.0 # arrow size + p2 = p1 + size * numpy.array([numpy.cos(angle), numpy.sin(angle)]) - axes.annotate( - port_name, - xy = tuple(p1), - xytext = tuple(p2), - arrowprops = dict(arrowstyle="->", color='g', linewidth=1), - color = 'g', - fontsize = 8, - ) - - for target, refs in self.refs.items(): - if target is None: - continue - if not refs: - continue - assert library is not None - target_pat = library[target] - for ref in refs: - ref.as_pattern(target_pat).visualize( - library = library, - offset = offset, - overdraw = True, - line_color = line_color, - fill_color = fill_color, - filename = filename, + axes.annotate( + port_name, + xy = tuple(p1), + xytext = tuple(p2), + arrowprops = dict(arrowstyle="->", color='g', linewidth=1), + color = 'g', + fontsize = 8, ) axes.autoscale_view() diff --git a/masque/test/test_visualize.py b/masque/test/test_visualize.py new file mode 100644 index 0000000..4dab435 --- /dev/null +++ b/masque/test/test_visualize.py @@ -0,0 +1,55 @@ +import numpy as np +import pytest +from masque.pattern import Pattern +from masque.ports import Port +from masque.repetition import Grid + +try: + import matplotlib + HAS_MATPLOTLIB = True +except ImportError: + HAS_MATPLOTLIB = False + +@pytest.mark.skipif(not HAS_MATPLOTLIB, reason="matplotlib not installed") +def test_visualize_noninteractive(tmp_path) -> None: + """ + Test that visualize() runs and saves a file without error. + This covers the recursive transformation and collection logic. + """ + # Create a hierarchy + child = Pattern() + child.polygon('L1', [[0, 0], [1, 0], [1, 1], [0, 1]]) + child.ports['P1'] = Port((0.5, 0.5), 0) + + parent = Pattern() + # Add some refs with various transforms + parent.ref('child', offset=(10, 0), rotation=np.pi/4, mirrored=True, scale=2.0) + + # Add a repetition + rep = Grid(a_vector=(5, 5), a_count=2) + parent.ref('child', offset=(0, 10), repetition=rep) + + library = {'child': child} + + output_file = tmp_path / "test_plot.png" + + # Run visualize with filename to avoid showing window + parent.visualize(library=library, filename=str(output_file), ports=True) + + assert output_file.exists() + assert output_file.stat().st_size > 0 + +@pytest.mark.skipif(not HAS_MATPLOTLIB, reason="matplotlib not installed") +def test_visualize_empty() -> None: + """ Test visualizing an empty pattern. """ + pat = Pattern() + # Should not raise + pat.visualize(overdraw=True) + +@pytest.mark.skipif(not HAS_MATPLOTLIB, reason="matplotlib not installed") +def test_visualize_no_refs() -> None: + """ Test visualizing a pattern with only local shapes (no library needed). """ + pat = Pattern() + pat.polygon('L1', [[0, 0], [1, 0], [0, 1]]) + # Should not raise even if library is None + pat.visualize(overdraw=True)