[pattern] speed up visualize()

This commit is contained in:
jan 2026-03-07 23:57:12 -08:00
commit 338c123fb1
2 changed files with 188 additions and 49 deletions

View file

@ -1061,12 +1061,13 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
klayout or a different GDS viewer! klayout or a different GDS viewer!
Args: Args:
offset: Coordinates to offset by before drawing library: Mapping of {name: Pattern} for resolving references. Required if `self.has_refs()`.
line_color: Outlines are drawn with this color (passed to `matplotlib.collections.PolyCollection`) offset: Coordinates to offset by before drawing.
fill_color: Interiors are drawn with this color (passed to `matplotlib.collections.PolyCollection`) line_color: Outlines are drawn with this color.
overdraw: Whether to create a new figure or draw on a pre-existing one fill_color: Interiors are drawn with this color.
overdraw: Whether to create a new figure or draw on a pre-existing one.
filename: If provided, save the figure to this file instead of showing it. filename: If provided, save the figure to this file instead of showing it.
ports: If True, annotate the plot with arrows representing the ports ports: If True, annotate the plot with arrows representing the ports.
""" """
# TODO: add text labels to visualize() # TODO: add text labels to visualize()
try: try:
@ -1080,8 +1081,113 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
if self.has_refs() and library is None: if self.has_refs() and library is None:
raise PatternError('Must provide a library when visualizing a pattern with refs') raise PatternError('Must provide a library when visualizing a pattern with refs')
offset = numpy.asarray(offset, dtype=float) # Cache for {Pattern object ID: List of local polygon vertex arrays}
# Polygons are stored relative to the pattern's origin (offset included)
poly_cache: dict[int, list[NDArray[numpy.float64]]] = {}
def get_local_polys(pat: 'Pattern') -> list[NDArray[numpy.float64]]:
pid = id(pat)
if pid not in poly_cache:
polys = []
for shape in chain.from_iterable(pat.shapes.values()):
for ss in shape.to_polygons():
# Shape.to_polygons() returns Polygons with their own offsets and vertices.
# We need to expand any shape-level repetition here.
v_base = ss.vertices + ss.offset
if ss.repetition is not None:
for disp in ss.repetition.displacements:
polys.append(v_base + disp)
else:
polys.append(v_base)
poly_cache[pid] = polys
return poly_cache[pid]
all_polygons: list[NDArray[numpy.float64]] = []
port_info: list[tuple[str, NDArray[numpy.float64], float]] = []
def collect_polys_recursive(
pat: 'Pattern',
c_offset: NDArray[numpy.float64],
c_rotation: float,
c_mirrored: bool,
c_scale: float,
) -> None:
# Current transform: T(c_offset) * R(c_rotation) * M(c_mirrored) * S(c_scale)
# 1. Transform and collect local polygons
local_polys = get_local_polys(pat)
if local_polys:
rot_mat = rotation_matrix_2d(c_rotation)
for v in local_polys:
vt = v * c_scale
if c_mirrored:
vt = vt.copy()
vt[:, 1] *= -1
vt = (rot_mat @ vt.T).T + c_offset
all_polygons.append(vt)
# 2. Collect ports if requested
if ports:
for name, p in pat.ports.items():
pt_v = p.offset * c_scale
if c_mirrored:
pt_v = pt_v.copy()
pt_v[1] *= -1
pt_v = rotation_matrix_2d(c_rotation) @ pt_v + c_offset
if p.rotation is not None:
pt_rot = p.rotation
if c_mirrored:
pt_rot = -pt_rot
pt_rot += c_rotation
port_info.append((name, pt_v, pt_rot))
# 3. Recurse into refs
for target, refs in pat.refs.items():
if target is None:
continue
target_pat = library[target]
for ref in refs:
# Ref order of operations: mirror, rotate, scale, translate, repeat
# Combined scale and mirror
r_scale = c_scale * ref.scale
r_mirrored = c_mirrored ^ ref.mirrored
# Combined rotation: push c_mirrored and c_rotation through ref.rotation
r_rot_relative = -ref.rotation if c_mirrored else ref.rotation
r_rotation = c_rotation + r_rot_relative
# Offset composition helper
def get_full_offset(rel_offset: NDArray[numpy.float64]) -> NDArray[numpy.float64]:
o = rel_offset * c_scale
if c_mirrored:
o = o.copy()
o[1] *= -1
return rotation_matrix_2d(c_rotation) @ o + c_offset
if ref.repetition is not None:
for disp in ref.repetition.displacements:
collect_polys_recursive(
target_pat,
get_full_offset(ref.offset + disp),
r_rotation,
r_mirrored,
r_scale
)
else:
collect_polys_recursive(
target_pat,
get_full_offset(ref.offset),
r_rotation,
r_mirrored,
r_scale
)
# Start recursive collection
collect_polys_recursive(self, numpy.asarray(offset, dtype=float), 0.0, False, 1.0)
# Plotting
if not overdraw: if not overdraw:
figure = pyplot.figure() figure = pyplot.figure()
else: else:
@ -1089,50 +1195,28 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
axes = figure.gca() axes = figure.gca()
polygons = [] if all_polygons:
for shape in chain.from_iterable(self.shapes.values()): mpl_poly_collection = matplotlib.collections.PolyCollection(
for ss in shape.to_polygons(): all_polygons,
polygons.append(offset + ss.offset + ss.vertices) facecolors = fill_color,
edgecolors = line_color,
mpl_poly_collection = matplotlib.collections.PolyCollection( )
polygons, axes.add_collection(mpl_poly_collection)
facecolors = fill_color,
edgecolors = line_color,
)
axes.add_collection(mpl_poly_collection)
if ports: if ports:
for port_name, port in self.ports.items(): for port_name, pt_v, pt_rot in port_info:
if port.rotation is not None: p1 = pt_v
p1 = offset + port.offset angle = pt_rot
angle = port.rotation size = 1.0 # arrow size
size = 1.0 # arrow size p2 = p1 + size * numpy.array([numpy.cos(angle), numpy.sin(angle)])
p2 = p1 + size * numpy.array([numpy.cos(angle), numpy.sin(angle)])
axes.annotate( axes.annotate(
port_name, port_name,
xy = tuple(p1), xy = tuple(p1),
xytext = tuple(p2), xytext = tuple(p2),
arrowprops = dict(arrowstyle="->", color='g', linewidth=1), arrowprops = dict(arrowstyle="->", color='g', linewidth=1),
color = 'g', color = 'g',
fontsize = 8, fontsize = 8,
)
for target, refs in self.refs.items():
if target is None:
continue
if not refs:
continue
assert library is not None
target_pat = library[target]
for ref in refs:
ref.as_pattern(target_pat).visualize(
library = library,
offset = offset,
overdraw = True,
line_color = line_color,
fill_color = fill_color,
filename = filename,
) )
axes.autoscale_view() axes.autoscale_view()

View file

@ -0,0 +1,55 @@
import numpy as np
import pytest
from masque.pattern import Pattern
from masque.ports import Port
from masque.repetition import Grid
try:
import matplotlib
HAS_MATPLOTLIB = True
except ImportError:
HAS_MATPLOTLIB = False
@pytest.mark.skipif(not HAS_MATPLOTLIB, reason="matplotlib not installed")
def test_visualize_noninteractive(tmp_path) -> None:
"""
Test that visualize() runs and saves a file without error.
This covers the recursive transformation and collection logic.
"""
# Create a hierarchy
child = Pattern()
child.polygon('L1', [[0, 0], [1, 0], [1, 1], [0, 1]])
child.ports['P1'] = Port((0.5, 0.5), 0)
parent = Pattern()
# Add some refs with various transforms
parent.ref('child', offset=(10, 0), rotation=np.pi/4, mirrored=True, scale=2.0)
# Add a repetition
rep = Grid(a_vector=(5, 5), a_count=2)
parent.ref('child', offset=(0, 10), repetition=rep)
library = {'child': child}
output_file = tmp_path / "test_plot.png"
# Run visualize with filename to avoid showing window
parent.visualize(library=library, filename=str(output_file), ports=True)
assert output_file.exists()
assert output_file.stat().st_size > 0
@pytest.mark.skipif(not HAS_MATPLOTLIB, reason="matplotlib not installed")
def test_visualize_empty() -> None:
""" Test visualizing an empty pattern. """
pat = Pattern()
# Should not raise
pat.visualize(overdraw=True)
@pytest.mark.skipif(not HAS_MATPLOTLIB, reason="matplotlib not installed")
def test_visualize_no_refs() -> None:
""" Test visualizing a pattern with only local shapes (no library needed). """
pat = Pattern()
pat.polygon('L1', [[0, 0], [1, 0], [0, 1]])
# Should not raise even if library is None
pat.visualize(overdraw=True)