[pattern] speed up visualize()
This commit is contained in:
parent
a89f07c441
commit
338c123fb1
2 changed files with 188 additions and 49 deletions
|
|
@ -1061,12 +1061,13 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
|
||||||
klayout or a different GDS viewer!
|
klayout or a different GDS viewer!
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
offset: Coordinates to offset by before drawing
|
library: Mapping of {name: Pattern} for resolving references. Required if `self.has_refs()`.
|
||||||
line_color: Outlines are drawn with this color (passed to `matplotlib.collections.PolyCollection`)
|
offset: Coordinates to offset by before drawing.
|
||||||
fill_color: Interiors are drawn with this color (passed to `matplotlib.collections.PolyCollection`)
|
line_color: Outlines are drawn with this color.
|
||||||
overdraw: Whether to create a new figure or draw on a pre-existing one
|
fill_color: Interiors are drawn with this color.
|
||||||
|
overdraw: Whether to create a new figure or draw on a pre-existing one.
|
||||||
filename: If provided, save the figure to this file instead of showing it.
|
filename: If provided, save the figure to this file instead of showing it.
|
||||||
ports: If True, annotate the plot with arrows representing the ports
|
ports: If True, annotate the plot with arrows representing the ports.
|
||||||
"""
|
"""
|
||||||
# TODO: add text labels to visualize()
|
# TODO: add text labels to visualize()
|
||||||
try:
|
try:
|
||||||
|
|
@ -1080,8 +1081,113 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
|
||||||
if self.has_refs() and library is None:
|
if self.has_refs() and library is None:
|
||||||
raise PatternError('Must provide a library when visualizing a pattern with refs')
|
raise PatternError('Must provide a library when visualizing a pattern with refs')
|
||||||
|
|
||||||
offset = numpy.asarray(offset, dtype=float)
|
# Cache for {Pattern object ID: List of local polygon vertex arrays}
|
||||||
|
# Polygons are stored relative to the pattern's origin (offset included)
|
||||||
|
poly_cache: dict[int, list[NDArray[numpy.float64]]] = {}
|
||||||
|
|
||||||
|
def get_local_polys(pat: 'Pattern') -> list[NDArray[numpy.float64]]:
|
||||||
|
pid = id(pat)
|
||||||
|
if pid not in poly_cache:
|
||||||
|
polys = []
|
||||||
|
for shape in chain.from_iterable(pat.shapes.values()):
|
||||||
|
for ss in shape.to_polygons():
|
||||||
|
# Shape.to_polygons() returns Polygons with their own offsets and vertices.
|
||||||
|
# We need to expand any shape-level repetition here.
|
||||||
|
v_base = ss.vertices + ss.offset
|
||||||
|
if ss.repetition is not None:
|
||||||
|
for disp in ss.repetition.displacements:
|
||||||
|
polys.append(v_base + disp)
|
||||||
|
else:
|
||||||
|
polys.append(v_base)
|
||||||
|
poly_cache[pid] = polys
|
||||||
|
return poly_cache[pid]
|
||||||
|
|
||||||
|
all_polygons: list[NDArray[numpy.float64]] = []
|
||||||
|
port_info: list[tuple[str, NDArray[numpy.float64], float]] = []
|
||||||
|
|
||||||
|
def collect_polys_recursive(
|
||||||
|
pat: 'Pattern',
|
||||||
|
c_offset: NDArray[numpy.float64],
|
||||||
|
c_rotation: float,
|
||||||
|
c_mirrored: bool,
|
||||||
|
c_scale: float,
|
||||||
|
) -> None:
|
||||||
|
# Current transform: T(c_offset) * R(c_rotation) * M(c_mirrored) * S(c_scale)
|
||||||
|
|
||||||
|
# 1. Transform and collect local polygons
|
||||||
|
local_polys = get_local_polys(pat)
|
||||||
|
if local_polys:
|
||||||
|
rot_mat = rotation_matrix_2d(c_rotation)
|
||||||
|
for v in local_polys:
|
||||||
|
vt = v * c_scale
|
||||||
|
if c_mirrored:
|
||||||
|
vt = vt.copy()
|
||||||
|
vt[:, 1] *= -1
|
||||||
|
vt = (rot_mat @ vt.T).T + c_offset
|
||||||
|
all_polygons.append(vt)
|
||||||
|
|
||||||
|
# 2. Collect ports if requested
|
||||||
|
if ports:
|
||||||
|
for name, p in pat.ports.items():
|
||||||
|
pt_v = p.offset * c_scale
|
||||||
|
if c_mirrored:
|
||||||
|
pt_v = pt_v.copy()
|
||||||
|
pt_v[1] *= -1
|
||||||
|
pt_v = rotation_matrix_2d(c_rotation) @ pt_v + c_offset
|
||||||
|
|
||||||
|
if p.rotation is not None:
|
||||||
|
pt_rot = p.rotation
|
||||||
|
if c_mirrored:
|
||||||
|
pt_rot = -pt_rot
|
||||||
|
pt_rot += c_rotation
|
||||||
|
port_info.append((name, pt_v, pt_rot))
|
||||||
|
|
||||||
|
# 3. Recurse into refs
|
||||||
|
for target, refs in pat.refs.items():
|
||||||
|
if target is None:
|
||||||
|
continue
|
||||||
|
target_pat = library[target]
|
||||||
|
for ref in refs:
|
||||||
|
# Ref order of operations: mirror, rotate, scale, translate, repeat
|
||||||
|
|
||||||
|
# Combined scale and mirror
|
||||||
|
r_scale = c_scale * ref.scale
|
||||||
|
r_mirrored = c_mirrored ^ ref.mirrored
|
||||||
|
|
||||||
|
# Combined rotation: push c_mirrored and c_rotation through ref.rotation
|
||||||
|
r_rot_relative = -ref.rotation if c_mirrored else ref.rotation
|
||||||
|
r_rotation = c_rotation + r_rot_relative
|
||||||
|
|
||||||
|
# Offset composition helper
|
||||||
|
def get_full_offset(rel_offset: NDArray[numpy.float64]) -> NDArray[numpy.float64]:
|
||||||
|
o = rel_offset * c_scale
|
||||||
|
if c_mirrored:
|
||||||
|
o = o.copy()
|
||||||
|
o[1] *= -1
|
||||||
|
return rotation_matrix_2d(c_rotation) @ o + c_offset
|
||||||
|
|
||||||
|
if ref.repetition is not None:
|
||||||
|
for disp in ref.repetition.displacements:
|
||||||
|
collect_polys_recursive(
|
||||||
|
target_pat,
|
||||||
|
get_full_offset(ref.offset + disp),
|
||||||
|
r_rotation,
|
||||||
|
r_mirrored,
|
||||||
|
r_scale
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
collect_polys_recursive(
|
||||||
|
target_pat,
|
||||||
|
get_full_offset(ref.offset),
|
||||||
|
r_rotation,
|
||||||
|
r_mirrored,
|
||||||
|
r_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start recursive collection
|
||||||
|
collect_polys_recursive(self, numpy.asarray(offset, dtype=float), 0.0, False, 1.0)
|
||||||
|
|
||||||
|
# Plotting
|
||||||
if not overdraw:
|
if not overdraw:
|
||||||
figure = pyplot.figure()
|
figure = pyplot.figure()
|
||||||
else:
|
else:
|
||||||
|
|
@ -1089,23 +1195,18 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
|
||||||
|
|
||||||
axes = figure.gca()
|
axes = figure.gca()
|
||||||
|
|
||||||
polygons = []
|
if all_polygons:
|
||||||
for shape in chain.from_iterable(self.shapes.values()):
|
|
||||||
for ss in shape.to_polygons():
|
|
||||||
polygons.append(offset + ss.offset + ss.vertices)
|
|
||||||
|
|
||||||
mpl_poly_collection = matplotlib.collections.PolyCollection(
|
mpl_poly_collection = matplotlib.collections.PolyCollection(
|
||||||
polygons,
|
all_polygons,
|
||||||
facecolors = fill_color,
|
facecolors = fill_color,
|
||||||
edgecolors = line_color,
|
edgecolors = line_color,
|
||||||
)
|
)
|
||||||
axes.add_collection(mpl_poly_collection)
|
axes.add_collection(mpl_poly_collection)
|
||||||
|
|
||||||
if ports:
|
if ports:
|
||||||
for port_name, port in self.ports.items():
|
for port_name, pt_v, pt_rot in port_info:
|
||||||
if port.rotation is not None:
|
p1 = pt_v
|
||||||
p1 = offset + port.offset
|
angle = pt_rot
|
||||||
angle = port.rotation
|
|
||||||
size = 1.0 # arrow size
|
size = 1.0 # arrow size
|
||||||
p2 = p1 + size * numpy.array([numpy.cos(angle), numpy.sin(angle)])
|
p2 = p1 + size * numpy.array([numpy.cos(angle), numpy.sin(angle)])
|
||||||
|
|
||||||
|
|
@ -1118,23 +1219,6 @@ class Pattern(PortList, AnnotatableImpl, Mirrorable):
|
||||||
fontsize = 8,
|
fontsize = 8,
|
||||||
)
|
)
|
||||||
|
|
||||||
for target, refs in self.refs.items():
|
|
||||||
if target is None:
|
|
||||||
continue
|
|
||||||
if not refs:
|
|
||||||
continue
|
|
||||||
assert library is not None
|
|
||||||
target_pat = library[target]
|
|
||||||
for ref in refs:
|
|
||||||
ref.as_pattern(target_pat).visualize(
|
|
||||||
library = library,
|
|
||||||
offset = offset,
|
|
||||||
overdraw = True,
|
|
||||||
line_color = line_color,
|
|
||||||
fill_color = fill_color,
|
|
||||||
filename = filename,
|
|
||||||
)
|
|
||||||
|
|
||||||
axes.autoscale_view()
|
axes.autoscale_view()
|
||||||
axes.set_aspect('equal')
|
axes.set_aspect('equal')
|
||||||
|
|
||||||
|
|
|
||||||
55
masque/test/test_visualize.py
Normal file
55
masque/test/test_visualize.py
Normal file
|
|
@ -0,0 +1,55 @@
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from masque.pattern import Pattern
|
||||||
|
from masque.ports import Port
|
||||||
|
from masque.repetition import Grid
|
||||||
|
|
||||||
|
try:
|
||||||
|
import matplotlib
|
||||||
|
HAS_MATPLOTLIB = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_MATPLOTLIB = False
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_MATPLOTLIB, reason="matplotlib not installed")
|
||||||
|
def test_visualize_noninteractive(tmp_path) -> None:
|
||||||
|
"""
|
||||||
|
Test that visualize() runs and saves a file without error.
|
||||||
|
This covers the recursive transformation and collection logic.
|
||||||
|
"""
|
||||||
|
# Create a hierarchy
|
||||||
|
child = Pattern()
|
||||||
|
child.polygon('L1', [[0, 0], [1, 0], [1, 1], [0, 1]])
|
||||||
|
child.ports['P1'] = Port((0.5, 0.5), 0)
|
||||||
|
|
||||||
|
parent = Pattern()
|
||||||
|
# Add some refs with various transforms
|
||||||
|
parent.ref('child', offset=(10, 0), rotation=np.pi/4, mirrored=True, scale=2.0)
|
||||||
|
|
||||||
|
# Add a repetition
|
||||||
|
rep = Grid(a_vector=(5, 5), a_count=2)
|
||||||
|
parent.ref('child', offset=(0, 10), repetition=rep)
|
||||||
|
|
||||||
|
library = {'child': child}
|
||||||
|
|
||||||
|
output_file = tmp_path / "test_plot.png"
|
||||||
|
|
||||||
|
# Run visualize with filename to avoid showing window
|
||||||
|
parent.visualize(library=library, filename=str(output_file), ports=True)
|
||||||
|
|
||||||
|
assert output_file.exists()
|
||||||
|
assert output_file.stat().st_size > 0
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_MATPLOTLIB, reason="matplotlib not installed")
|
||||||
|
def test_visualize_empty() -> None:
|
||||||
|
""" Test visualizing an empty pattern. """
|
||||||
|
pat = Pattern()
|
||||||
|
# Should not raise
|
||||||
|
pat.visualize(overdraw=True)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_MATPLOTLIB, reason="matplotlib not installed")
|
||||||
|
def test_visualize_no_refs() -> None:
|
||||||
|
""" Test visualizing a pattern with only local shapes (no library needed). """
|
||||||
|
pat = Pattern()
|
||||||
|
pat.polygon('L1', [[0, 0], [1, 0], [0, 1]])
|
||||||
|
# Should not raise even if library is None
|
||||||
|
pat.visualize(overdraw=True)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue