[pattern] speed up visualize()
This commit is contained in:
parent
a89f07c441
commit
338c123fb1
2 changed files with 188 additions and 49 deletions
|
|
@ -1061,12 +1061,13 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
|
|||
klayout or a different GDS viewer!
|
||||
|
||||
Args:
|
||||
offset: Coordinates to offset by before drawing
|
||||
line_color: Outlines are drawn with this color (passed to `matplotlib.collections.PolyCollection`)
|
||||
fill_color: Interiors are drawn with this color (passed to `matplotlib.collections.PolyCollection`)
|
||||
overdraw: Whether to create a new figure or draw on a pre-existing one
|
||||
library: Mapping of {name: Pattern} for resolving references. Required if `self.has_refs()`.
|
||||
offset: Coordinates to offset by before drawing.
|
||||
line_color: Outlines are drawn with this color.
|
||||
fill_color: Interiors are drawn with this color.
|
||||
overdraw: Whether to create a new figure or draw on a pre-existing one.
|
||||
filename: If provided, save the figure to this file instead of showing it.
|
||||
ports: If True, annotate the plot with arrows representing the ports
|
||||
ports: If True, annotate the plot with arrows representing the ports.
|
||||
"""
|
||||
# TODO: add text labels to visualize()
|
||||
try:
|
||||
|
|
@ -1080,8 +1081,113 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
|
|||
if self.has_refs() and library is None:
|
||||
raise PatternError('Must provide a library when visualizing a pattern with refs')
|
||||
|
||||
offset = numpy.asarray(offset, dtype=float)
|
||||
# Cache for {Pattern object ID: List of local polygon vertex arrays}
|
||||
# Polygons are stored relative to the pattern's origin (offset included)
|
||||
poly_cache: dict[int, list[NDArray[numpy.float64]]] = {}
|
||||
|
||||
def get_local_polys(pat: 'Pattern') -> list[NDArray[numpy.float64]]:
|
||||
pid = id(pat)
|
||||
if pid not in poly_cache:
|
||||
polys = []
|
||||
for shape in chain.from_iterable(pat.shapes.values()):
|
||||
for ss in shape.to_polygons():
|
||||
# Shape.to_polygons() returns Polygons with their own offsets and vertices.
|
||||
# We need to expand any shape-level repetition here.
|
||||
v_base = ss.vertices + ss.offset
|
||||
if ss.repetition is not None:
|
||||
for disp in ss.repetition.displacements:
|
||||
polys.append(v_base + disp)
|
||||
else:
|
||||
polys.append(v_base)
|
||||
poly_cache[pid] = polys
|
||||
return poly_cache[pid]
|
||||
|
||||
all_polygons: list[NDArray[numpy.float64]] = []
|
||||
port_info: list[tuple[str, NDArray[numpy.float64], float]] = []
|
||||
|
||||
def collect_polys_recursive(
|
||||
pat: 'Pattern',
|
||||
c_offset: NDArray[numpy.float64],
|
||||
c_rotation: float,
|
||||
c_mirrored: bool,
|
||||
c_scale: float,
|
||||
) -> None:
|
||||
# Current transform: T(c_offset) * R(c_rotation) * M(c_mirrored) * S(c_scale)
|
||||
|
||||
# 1. Transform and collect local polygons
|
||||
local_polys = get_local_polys(pat)
|
||||
if local_polys:
|
||||
rot_mat = rotation_matrix_2d(c_rotation)
|
||||
for v in local_polys:
|
||||
vt = v * c_scale
|
||||
if c_mirrored:
|
||||
vt = vt.copy()
|
||||
vt[:, 1] *= -1
|
||||
vt = (rot_mat @ vt.T).T + c_offset
|
||||
all_polygons.append(vt)
|
||||
|
||||
# 2. Collect ports if requested
|
||||
if ports:
|
||||
for name, p in pat.ports.items():
|
||||
pt_v = p.offset * c_scale
|
||||
if c_mirrored:
|
||||
pt_v = pt_v.copy()
|
||||
pt_v[1] *= -1
|
||||
pt_v = rotation_matrix_2d(c_rotation) @ pt_v + c_offset
|
||||
|
||||
if p.rotation is not None:
|
||||
pt_rot = p.rotation
|
||||
if c_mirrored:
|
||||
pt_rot = -pt_rot
|
||||
pt_rot += c_rotation
|
||||
port_info.append((name, pt_v, pt_rot))
|
||||
|
||||
# 3. Recurse into refs
|
||||
for target, refs in pat.refs.items():
|
||||
if target is None:
|
||||
continue
|
||||
target_pat = library[target]
|
||||
for ref in refs:
|
||||
# Ref order of operations: mirror, rotate, scale, translate, repeat
|
||||
|
||||
# Combined scale and mirror
|
||||
r_scale = c_scale * ref.scale
|
||||
r_mirrored = c_mirrored ^ ref.mirrored
|
||||
|
||||
# Combined rotation: push c_mirrored and c_rotation through ref.rotation
|
||||
r_rot_relative = -ref.rotation if c_mirrored else ref.rotation
|
||||
r_rotation = c_rotation + r_rot_relative
|
||||
|
||||
# Offset composition helper
|
||||
def get_full_offset(rel_offset: NDArray[numpy.float64]) -> NDArray[numpy.float64]:
|
||||
o = rel_offset * c_scale
|
||||
if c_mirrored:
|
||||
o = o.copy()
|
||||
o[1] *= -1
|
||||
return rotation_matrix_2d(c_rotation) @ o + c_offset
|
||||
|
||||
if ref.repetition is not None:
|
||||
for disp in ref.repetition.displacements:
|
||||
collect_polys_recursive(
|
||||
target_pat,
|
||||
get_full_offset(ref.offset + disp),
|
||||
r_rotation,
|
||||
r_mirrored,
|
||||
r_scale
|
||||
)
|
||||
else:
|
||||
collect_polys_recursive(
|
||||
target_pat,
|
||||
get_full_offset(ref.offset),
|
||||
r_rotation,
|
||||
r_mirrored,
|
||||
r_scale
|
||||
)
|
||||
|
||||
# Start recursive collection
|
||||
collect_polys_recursive(self, numpy.asarray(offset, dtype=float), 0.0, False, 1.0)
|
||||
|
||||
# Plotting
|
||||
if not overdraw:
|
||||
figure = pyplot.figure()
|
||||
else:
|
||||
|
|
@ -1089,23 +1195,18 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
|
|||
|
||||
axes = figure.gca()
|
||||
|
||||
polygons = []
|
||||
for shape in chain.from_iterable(self.shapes.values()):
|
||||
for ss in shape.to_polygons():
|
||||
polygons.append(offset + ss.offset + ss.vertices)
|
||||
|
||||
if all_polygons:
|
||||
mpl_poly_collection = matplotlib.collections.PolyCollection(
|
||||
polygons,
|
||||
all_polygons,
|
||||
facecolors = fill_color,
|
||||
edgecolors = line_color,
|
||||
)
|
||||
axes.add_collection(mpl_poly_collection)
|
||||
|
||||
if ports:
|
||||
for port_name, port in self.ports.items():
|
||||
if port.rotation is not None:
|
||||
p1 = offset + port.offset
|
||||
angle = port.rotation
|
||||
for port_name, pt_v, pt_rot in port_info:
|
||||
p1 = pt_v
|
||||
angle = pt_rot
|
||||
size = 1.0 # arrow size
|
||||
p2 = p1 + size * numpy.array([numpy.cos(angle), numpy.sin(angle)])
|
||||
|
||||
|
|
@ -1118,23 +1219,6 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
|
|||
fontsize = 8,
|
||||
)
|
||||
|
||||
for target, refs in self.refs.items():
|
||||
if target is None:
|
||||
continue
|
||||
if not refs:
|
||||
continue
|
||||
assert library is not None
|
||||
target_pat = library[target]
|
||||
for ref in refs:
|
||||
ref.as_pattern(target_pat).visualize(
|
||||
library = library,
|
||||
offset = offset,
|
||||
overdraw = True,
|
||||
line_color = line_color,
|
||||
fill_color = fill_color,
|
||||
filename = filename,
|
||||
)
|
||||
|
||||
axes.autoscale_view()
|
||||
axes.set_aspect('equal')
|
||||
|
||||
|
|
|
|||
55
masque/test/test_visualize.py
Normal file
55
masque/test/test_visualize.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
from masque.pattern import Pattern
|
||||
from masque.ports import Port
|
||||
from masque.repetition import Grid
|
||||
|
||||
try:
|
||||
import matplotlib
|
||||
HAS_MATPLOTLIB = True
|
||||
except ImportError:
|
||||
HAS_MATPLOTLIB = False
|
||||
|
||||
@pytest.mark.skipif(not HAS_MATPLOTLIB, reason="matplotlib not installed")
|
||||
def test_visualize_noninteractive(tmp_path) -> None:
|
||||
"""
|
||||
Test that visualize() runs and saves a file without error.
|
||||
This covers the recursive transformation and collection logic.
|
||||
"""
|
||||
# Create a hierarchy
|
||||
child = Pattern()
|
||||
child.polygon('L1', [[0, 0], [1, 0], [1, 1], [0, 1]])
|
||||
child.ports['P1'] = Port((0.5, 0.5), 0)
|
||||
|
||||
parent = Pattern()
|
||||
# Add some refs with various transforms
|
||||
parent.ref('child', offset=(10, 0), rotation=np.pi/4, mirrored=True, scale=2.0)
|
||||
|
||||
# Add a repetition
|
||||
rep = Grid(a_vector=(5, 5), a_count=2)
|
||||
parent.ref('child', offset=(0, 10), repetition=rep)
|
||||
|
||||
library = {'child': child}
|
||||
|
||||
output_file = tmp_path / "test_plot.png"
|
||||
|
||||
# Run visualize with filename to avoid showing window
|
||||
parent.visualize(library=library, filename=str(output_file), ports=True)
|
||||
|
||||
assert output_file.exists()
|
||||
assert output_file.stat().st_size > 0
|
||||
|
||||
@pytest.mark.skipif(not HAS_MATPLOTLIB, reason="matplotlib not installed")
|
||||
def test_visualize_empty() -> None:
|
||||
""" Test visualizing an empty pattern. """
|
||||
pat = Pattern()
|
||||
# Should not raise
|
||||
pat.visualize(overdraw=True)
|
||||
|
||||
@pytest.mark.skipif(not HAS_MATPLOTLIB, reason="matplotlib not installed")
|
||||
def test_visualize_no_refs() -> None:
|
||||
""" Test visualizing a pattern with only local shapes (no library needed). """
|
||||
pat = Pattern()
|
||||
pat.polygon('L1', [[0, 0], [1, 0], [0, 1]])
|
||||
# Should not raise even if library is None
|
||||
pat.visualize(overdraw=True)
|
||||
Loading…
Add table
Add a link
Reference in a new issue