Compare commits

...

17 commits

Author SHA1 Message Date
75a9114709 [bezier] validate weights 2026-04-01 21:16:03 -07:00
df578d7764 [PolyCollection] copy vertex offsets when making normalized form 2026-04-01 21:15:44 -07:00
786716fc62 [preflight] document that preflight doesn't copy the library 2026-04-01 20:58:10 -07:00
a82365ec8c [svg] fix duplicate svg ids 2026-04-01 20:57:35 -07:00
28be89f047 [gdsii] make sure iterable is supported 2026-04-01 20:56:59 -07:00
afc49f945d [DeferredDict] add setdefault(), pop(), popitem(), copy() 2026-04-01 20:14:53 -07:00
ce46cc18dc [tmpfile] delete the temporary file if an error occurs 2026-04-01 20:12:24 -07:00
7c50f95fde [ILibrary] update docs for add() 2026-04-01 20:00:46 -07:00
ae314cce93 [ILibraryView] child_order shouldn't leak graphlib.CycleError 2026-04-01 19:59:59 -07:00
09a95a6608 [ILibraryView] fix assignment during dfs() 2026-04-01 19:57:29 -07:00
fbe138d443 [data_to_ports] warn on invalid angle 2026-04-01 19:22:16 -07:00
4b416745da [repetition.Grid] check for invalid displacements or counts 2026-04-01 19:21:47 -07:00
0830dce50c [data_to_ports] don't leave the pattern dirty if we error out part-way 2026-04-01 19:10:50 -07:00
ac87179da2 [data_to_ports] warn that repetitions are not expanded 2026-04-01 19:01:47 -07:00
f0eea0382b [Ref] get_bounds_single should ignore repetition 2026-04-01 19:00:59 -07:00
0c9b435e94 [PortList.check_ports] Check for duplicate map_in/map_out values 2026-04-01 19:00:19 -07:00
f461222852 [ILibrary.add] respect mutate_other=False even without duplicate keys 2026-04-01 18:58:01 -07:00
20 changed files with 479 additions and 17 deletions

View file

@ -637,6 +637,7 @@ def check_valid_names(
max_length: Max allowed length
"""
names = tuple(names)
allowed_chars = set(string.ascii_letters + string.digits + '_?$')
bad_chars = [

View file

@ -30,6 +30,21 @@ def _ref_to_svg_transform(ref) -> str:
return f'matrix({a:g} {b:g} {c:g} {d:g} {e:g} {f:g})'
def _make_svg_ids(names: Mapping[str, Pattern]) -> dict[str, str]:
    """Build a unique SVG id for every pattern name.

    Distinct names may mangle to the same string (e.g. 'a b' and 'a-b'
    both become 'a_b'); colliding ids receive a numeric suffix
    ('a_b_2', 'a_b_3', ...) so that each pattern gets its own id.
    """
    svg_ids: dict[str, str] = {}
    used: set[str] = set()
    for name in names:
        candidate = mangle_name(name)
        unique_id = candidate
        counter = 1
        # Bump the suffix until the id is unused; ids produced earlier in the
        # loop are also in `used`, so uniqueness holds across all entries.
        while unique_id in used:
            counter += 1
            unique_id = f'{candidate}_{counter}'
        used.add(unique_id)
        svg_ids[name] = unique_id
    return svg_ids
def writefile(
library: Mapping[str, Pattern],
top: str,
@ -81,10 +96,11 @@ def writefile(
# Create file
svg = svgwrite.Drawing(filename, profile='full', viewBox=viewbox_string,
debug=(not custom_attributes))
svg_ids = _make_svg_ids(library)
# Now create a group for each pattern and add in any Boundary and Use elements
for name, pat in library.items():
svg_group = svg.g(id=mangle_name(name), fill='blue', stroke='red')
svg_group = svg.g(id=svg_ids[name], fill='blue', stroke='red')
for layer, shapes in pat.shapes.items():
for shape in shapes:
@ -123,11 +139,11 @@ def writefile(
continue
for ref in refs:
transform = _ref_to_svg_transform(ref)
use = svg.use(href='#' + mangle_name(target), transform=transform)
use = svg.use(href='#' + svg_ids[target], transform=transform)
svg_group.add(use)
svg.defs.add(svg_group)
svg.add(svg.use(href='#' + mangle_name(top)))
svg.add(svg.use(href='#' + svg_ids[top]))
svg.save()

View file

@ -33,6 +33,12 @@ def preflight(
Run a standard set of useful operations and checks, usually done immediately prior
to writing to a file (or immediately after reading).
Note that this helper is not copy-isolating. When `sort=True`, it constructs a new
`Library` wrapper around the same `Pattern` objects after sorting them in place, so
later mutating preflight steps such as `prune_empty_patterns` and
`wrap_repeated_shapes` may still mutate caller-owned patterns. Callers that need
isolation should deep-copy the library before calling `preflight()`.
Args:
sort: Whether to sort the patterns based on their names, and optionally sort the pattern contents.
Default True. Useful for reproducible builds.
@ -145,7 +151,11 @@ def tmpfile(path: str | pathlib.Path) -> Iterator[IO[bytes]]:
path = pathlib.Path(path)
suffixes = ''.join(path.suffixes)
with tempfile.NamedTemporaryFile(suffix=suffixes, delete=False) as tmp_stream:
yield tmp_stream
try:
yield tmp_stream
except Exception:
pathlib.Path(tmp_stream.name).unlink(missing_ok=True)
raise
try:
shutil.move(tmp_stream.name, path)

View file

@ -22,7 +22,7 @@ import copy
from pprint import pformat
from collections import defaultdict
from abc import ABCMeta, abstractmethod
from graphlib import TopologicalSorter
from graphlib import TopologicalSorter, CycleError
import numpy
from numpy.typing import ArrayLike, NDArray
@ -538,6 +538,7 @@ class ILibraryView(Mapping[str, 'Pattern'], metaclass=ABCMeta):
raise LibraryError('visit_* functions returned a new `Pattern` object'
' but no top-level name was provided in `hierarchy`')
del cast('ILibrary', self)[name]
cast('ILibrary', self)[name] = pattern
return self
@ -617,7 +618,13 @@ class ILibraryView(Mapping[str, 'Pattern'], metaclass=ABCMeta):
Return:
Topologically sorted list of pattern names.
"""
return cast('list[str]', list(TopologicalSorter(self.child_graph(dangling=dangling)).static_order()))
try:
return cast('list[str]', list(TopologicalSorter(self.child_graph(dangling=dangling)).static_order()))
except CycleError as exc:
cycle = exc.args[1] if len(exc.args) > 1 else None
if cycle is None:
raise LibraryError('Cycle found while building child order') from exc
raise LibraryError(f'Cycle found while building child order: {cycle}') from exc
def find_refs_local(
self,
@ -916,8 +923,8 @@ class ILibrary(ILibraryView, MutableMapping[str, 'Pattern'], metaclass=ABCMeta):
(default).
Returns:
A mapping of `{old_name: new_name}` for all `old_name`s in `other`. Unchanged
names map to themselves.
A mapping of `{old_name: new_name}` for all names in `other` which were
renamed while being added. Unchanged names are omitted.
Raises:
`LibraryError` if a duplicate name is encountered even after applying `rename_theirs()`.
@ -926,8 +933,13 @@ class ILibrary(ILibraryView, MutableMapping[str, 'Pattern'], metaclass=ABCMeta):
duplicates = set(self.keys()) & set(other.keys())
if not duplicates:
for key in other:
self._merge(key, other, key)
if mutate_other:
temp = other
else:
temp = Library(copy.deepcopy(dict(other)))
for key in temp:
self._merge(key, temp, key)
return {}
if mutate_other:

View file

@ -510,10 +510,19 @@ class PortList(metaclass=ABCMeta):
if missing_invals:
raise PortError(f'`map_in` values not present in other device: {missing_invals}')
map_in_counts = Counter(map_in.values())
conflicts_in = {kk for kk, vv in map_in_counts.items() if vv > 1}
if conflicts_in:
raise PortError(f'Duplicate values in `map_in`: {conflicts_in}')
missing_outkeys = set(map_out.keys()) - other
if missing_outkeys:
raise PortError(f'`map_out` keys not present in other device: {missing_outkeys}')
connected_outkeys = set(map_out.keys()) & set(map_in.values())
if connected_outkeys:
raise PortError(f'`map_out` keys conflict with connected ports: {connected_outkeys}')
orig_remaining = set(self.ports.keys()) - set(map_in.keys())
other_remaining = other - set(map_out.keys()) - set(map_in.values())
mapped_vals = set(map_out.values())

View file

@ -236,7 +236,10 @@ class Ref(
bounds = numpy.vstack((numpy.min(corners, axis=0),
numpy.max(corners, axis=0))) * self.scale + [self.offset]
return bounds
return self.as_pattern(pattern=pattern).get_bounds(library)
single_ref = self.deepcopy()
single_ref.repetition = None
return single_ref.as_pattern(pattern=pattern).get_bounds(library)
def __repr__(self) -> str:
rotation = f' r{numpy.rad2deg(self.rotation):g}' if self.rotation != 0 else ''

View file

@ -184,6 +184,8 @@ class Grid(Repetition):
def a_count(self, val: int) -> None:
if val != int(val):
raise PatternError('a_count must be convertable to an int!')
if int(val) < 1:
raise PatternError(f'Repetition has too-small a_count: {val}')
self._a_count = int(val)
# b_count property
@ -195,6 +197,8 @@ class Grid(Repetition):
def b_count(self, val: int) -> None:
if val != int(val):
raise PatternError('b_count must be convertable to an int!')
if int(val) < 1:
raise PatternError(f'Repetition has too-small b_count: {val}')
self._b_count = int(val)
@property
@ -325,7 +329,22 @@ class Arbitrary(Repetition):
@displacements.setter
def displacements(self, val: ArrayLike) -> None:
vala = numpy.array(val, dtype=float)
try:
vala = numpy.array(val, dtype=float)
except (TypeError, ValueError) as exc:
raise PatternError('displacements must be convertible to an Nx2 ndarray') from exc
if vala.size == 0:
self._displacements = numpy.empty((0, 2), dtype=float)
return
if vala.ndim == 1:
if vala.size != 2:
raise PatternError('displacements must be convertible to an Nx2 ndarray')
vala = vala.reshape(1, 2)
elif vala.ndim != 2 or vala.shape[1] != 2:
raise PatternError('displacements must be convertible to an Nx2 ndarray')
order = numpy.lexsort(vala.T[::-1]) # sortrows
self._displacements = vala[order]

View file

@ -219,7 +219,7 @@ class PolyCollection(Shape):
(offset, scale / norm_value, rotation, False),
lambda: PolyCollection(
vertex_lists=rotated_vertices * norm_value,
vertex_offsets=self._vertex_offsets,
vertex_offsets=self._vertex_offsets.copy(),
),
)

View file

@ -1,8 +1,10 @@
from pathlib import Path
from typing import cast
import numpy
import pytest
from numpy.testing import assert_equal, assert_allclose
from ..error import LibraryError
from ..pattern import Pattern
from ..library import Library
from ..file import gdsii
@ -69,3 +71,10 @@ def test_gdsii_annotations(tmp_path: Path) -> None:
read_ann = read_lib["cell"].shapes[(1, 0)][0].annotations
assert read_ann is not None
assert read_ann["1"] == ["hello"]
def test_gdsii_check_valid_names_validates_generator_lengths() -> None:
names = (name for name in ("a" * 40,))
with pytest.raises(LibraryError, match="invalid names"):
gdsii.check_valid_names(names)

View file

@ -221,6 +221,77 @@ def test_library_rename() -> None:
assert "old" not in lib["parent"].refs
def test_library_dfs_can_replace_existing_patterns() -> None:
lib = Library()
child = Pattern()
lib["child"] = child
top = Pattern()
top.ref("child")
lib["top"] = top
replacement_top = Pattern(ports={"T": Port((1, 2), 0)})
replacement_child = Pattern(ports={"C": Port((3, 4), 0)})
def visit_after(pattern: Pattern, hierarchy: tuple[str | None, ...], **kwargs) -> Pattern: # noqa: ARG001
if hierarchy[-1] == "child":
return replacement_child
if hierarchy[-1] == "top":
return replacement_top
return pattern
lib.dfs(lib["top"], visit_after=visit_after, hierarchy=("top",), transform=True)
assert lib["top"] is replacement_top
assert lib["child"] is replacement_child
def test_lazy_library_dfs_can_replace_existing_patterns() -> None:
lib = LazyLibrary()
lib["child"] = lambda: Pattern()
lib["top"] = lambda: Pattern(refs={"child": []})
top = lib["top"]
top.ref("child")
replacement_top = Pattern(ports={"T": Port((1, 2), 0)})
replacement_child = Pattern(ports={"C": Port((3, 4), 0)})
def visit_after(pattern: Pattern, hierarchy: tuple[str | None, ...], **kwargs) -> Pattern: # noqa: ARG001
if hierarchy[-1] == "child":
return replacement_child
if hierarchy[-1] == "top":
return replacement_top
return pattern
lib.dfs(top, visit_after=visit_after, hierarchy=("top",), transform=True)
assert lib["top"] is replacement_top
assert lib["child"] is replacement_child
def test_library_add_no_duplicates_respects_mutate_other_false() -> None:
src_pat = Pattern(ports={"A": Port((0, 0), 0)})
lib = Library({"a": Pattern()})
lib.add({"b": src_pat}, mutate_other=False)
assert lib["b"] is not src_pat
lib["b"].ports["A"].offset[0] = 123
assert tuple(src_pat.ports["A"].offset) == (0.0, 0.0)
def test_library_add_returns_only_renamed_entries() -> None:
lib = Library({"a": Pattern(), "_shape": Pattern()})
assert lib.add({"b": Pattern(), "c": Pattern()}, mutate_other=False) == {}
rename_map = lib.add({"_shape": Pattern(), "keep": Pattern()}, mutate_other=False)
assert set(rename_map) == {"_shape"}
assert rename_map["_shape"] != "_shape"
assert "keep" not in rename_map
def test_library_subtree() -> None:
lib = Library()
lib["a"] = Pattern()
@ -234,6 +305,26 @@ def test_library_subtree() -> None:
assert "c" not in sub
def test_library_child_order_cycle_raises_library_error() -> None:
lib = Library()
lib["a"] = Pattern()
lib["a"].ref("b")
lib["b"] = Pattern()
lib["b"].ref("a")
with pytest.raises(LibraryError, match="Cycle found while building child order"):
lib.child_order()
def test_library_find_refs_global_cycle_raises_library_error() -> None:
lib = Library()
lib["a"] = Pattern()
lib["a"].ref("a")
with pytest.raises(LibraryError, match="Cycle found while building child order"):
lib.find_refs_global("a")
def test_library_get_name() -> None:
lib = Library()
lib["cell"] = Pattern()

View file

@ -4,6 +4,7 @@ from numpy import pi
from ..ports import Port, PortList
from ..error import PortError
from ..pattern import Pattern
def test_port_init() -> None:
@ -227,3 +228,32 @@ def test_port_list_plugged_mismatch() -> None:
pl = MyPorts()
with pytest.raises(PortError):
pl.plugged({"A": "B"})
def test_port_list_check_ports_duplicate_map_in_values_raise() -> None:
class MyPorts(PortList):
def __init__(self) -> None:
self._ports = {"A": Port((0, 0), 0), "B": Port((0, 0), 0)}
@property
def ports(self) -> dict[str, Port]:
return self._ports
@ports.setter
def ports(self, val: dict[str, Port]) -> None:
self._ports = val
pl = MyPorts()
with pytest.raises(PortError, match="Duplicate values in `map_in`"):
pl.check_ports({"X", "Y"}, map_in={"A": "X", "B": "X"})
assert set(pl.ports) == {"A", "B"}
def test_pattern_plug_rejects_map_out_on_connected_ports_atomically() -> None:
host = Pattern(ports={"A": Port((0, 0), 0)})
other = Pattern(ports={"X": Port((0, 0), pi), "Y": Port((5, 0), 0)})
with pytest.raises(PortError, match="`map_out` keys conflict with connected ports"):
host.plug(other, {"A": "X"}, map_out={"X": "renamed", "Y": "out"}, append=True)
assert set(host.ports) == {"A"}

View file

@ -1,10 +1,13 @@
import numpy
import pytest
from numpy.testing import assert_allclose
from ..utils.ports2data import ports_to_data, data_to_ports
from ..pattern import Pattern
from ..ports import Port
from ..library import Library
from ..error import PortError
from ..repetition import Grid
def test_ports2data_roundtrip() -> None:
@ -74,3 +77,56 @@ def test_data_to_ports_hierarchical_scaled_ref() -> None:
assert_allclose(parent.ports["A"].offset, [100, 110], atol=1e-10)
assert parent.ports["A"].rotation is not None
assert_allclose(parent.ports["A"].rotation, numpy.pi / 2, atol=1e-10)
def test_data_to_ports_hierarchical_repeated_ref_warns_and_keeps_best_effort(
caplog: pytest.LogCaptureFixture,
) -> None:
lib = Library()
child = Pattern()
layer = (10, 0)
child.label(layer=layer, string="A:type1 0", offset=(5, 0))
lib["child"] = child
parent = Pattern()
parent.ref("child", repetition=Grid(a_vector=(100, 0), a_count=3))
caplog.set_level("WARNING")
data_to_ports([layer], lib, parent, max_depth=1)
assert "A" in parent.ports
assert_allclose(parent.ports["A"].offset, [5, 0], atol=1e-10)
assert any("importing only the base instance ports" in record.message for record in caplog.records)
def test_data_to_ports_hierarchical_collision_is_atomic() -> None:
lib = Library()
child = Pattern()
layer = (10, 0)
child.label(layer=layer, string="A:type1 0", offset=(5, 0))
lib["child"] = child
parent = Pattern()
parent.ref("child", offset=(0, 0))
parent.ref("child", offset=(10, 0))
with pytest.raises(PortError, match="Device ports conflict with existing ports"):
data_to_ports([layer], lib, parent, max_depth=1)
assert not parent.ports
def test_data_to_ports_flat_bad_angle_warns_and_skips(
caplog: pytest.LogCaptureFixture,
) -> None:
layer = (10, 0)
pat = Pattern()
pat.label(layer=layer, string="A:type1 nope", offset=(5, 0))
caplog.set_level("WARNING")
data_to_ports([layer], {}, pat)
assert not pat.ports
assert any('bad angle' in record.message for record in caplog.records)

View file

@ -64,6 +64,22 @@ def test_ref_get_bounds() -> None:
assert_equal(bounds, [[10, 10], [20, 20]])
def test_ref_get_bounds_single_ignores_repetition_for_non_manhattan_rotation() -> None:
sub_pat = Pattern()
sub_pat.rect((1, 0), xmin=0, xmax=1, ymin=0, ymax=2)
rep = Grid(a_vector=(5, 0), b_vector=(0, 7), a_count=3, b_count=2)
ref = Ref(offset=(10, 20), rotation=pi / 4, repetition=rep)
bounds = ref.get_bounds_single(sub_pat)
repeated_bounds = ref.get_bounds(sub_pat)
assert bounds is not None
assert repeated_bounds is not None
assert repeated_bounds[1, 0] > bounds[1, 0]
assert repeated_bounds[1, 1] > bounds[1, 1]
def test_ref_copy() -> None:
ref1 = Ref(offset=(1, 2), rotation=0.5, annotations={"a": [1]})
ref2 = ref1.copy()

View file

@ -1,7 +1,9 @@
import pytest
from numpy.testing import assert_equal, assert_allclose
from numpy import pi
from ..repetition import Grid, Arbitrary
from ..error import PatternError
def test_grid_displacements() -> None:
@ -51,6 +53,30 @@ def test_arbitrary_transform() -> None:
assert_allclose(arb.displacements, [[0, -10]], atol=1e-10)
def test_arbitrary_empty_repetition_is_allowed() -> None:
arb = Arbitrary([])
assert arb.displacements.shape == (0, 2)
assert arb.get_bounds() is None
def test_arbitrary_rejects_non_nx2_displacements() -> None:
for displacements in ([[1], [2]], [[1, 2, 3]], [1, 2, 3]):
with pytest.raises(PatternError, match='displacements must be convertible to an Nx2 ndarray'):
Arbitrary(displacements)
def test_grid_count_setters_reject_nonpositive_values() -> None:
for attr, value, message in (
('a_count', 0, 'a_count'),
('a_count', -1, 'a_count'),
('b_count', 0, 'b_count'),
('b_count', -1, 'b_count'),
):
grid = Grid(a_vector=(10, 0), b_vector=(0, 5), a_count=2, b_count=2)
with pytest.raises(PatternError, match=message):
setattr(grid, attr, value)
def test_repetition_less_equal_includes_equality() -> None:
grid_a = Grid(a_vector=(10, 0), a_count=2)
grid_b = Grid(a_vector=(10, 0), a_count=2)

View file

@ -212,3 +212,27 @@ def test_poly_collection_valid() -> None:
assert len(sorted_shapes) == 4
# Just verify it doesn't crash and is stable
assert sorted(sorted_shapes) == sorted_shapes
def test_poly_collection_normalized_form_reconstruction_is_independent() -> None:
pc = PolyCollection([[0, 0], [1, 0], [0, 1]], [0])
_intrinsic, _extrinsic, rebuild = pc.normalized_form(1)
clone = rebuild()
clone.vertex_offsets[:] = [5]
assert_equal(pc.vertex_offsets, [0])
assert_equal(clone.vertex_offsets, [5])
def test_poly_collection_normalized_form_rebuilds_independent_clones() -> None:
pc = PolyCollection([[0, 0], [1, 0], [0, 1]], [0])
_intrinsic, _extrinsic, rebuild = pc.normalized_form(1)
first = rebuild()
second = rebuild()
first.vertex_offsets[:] = [7]
assert_equal(first.vertex_offsets, [7])
assert_equal(second.vertex_offsets, [0])
assert_equal(pc.vertex_offsets, [0])

View file

@ -68,3 +68,31 @@ def test_svg_ref_mirroring_changes_affine_transform(tmp_path: Path) -> None:
assert_allclose(plain_transform, (0, 2, -2, 0, 3, 4), atol=1e-10)
assert_allclose(mirrored_transform, (0, 2, 2, 0, 3, 4), atol=1e-10)
def test_svg_uses_unique_ids_for_colliding_mangled_names(tmp_path: Path) -> None:
lib = Library()
first = Pattern()
first.polygon("1", vertices=[[0, 0], [1, 0], [0, 1]])
lib["a b"] = first
second = Pattern()
second.polygon("1", vertices=[[0, 0], [2, 0], [0, 2]])
lib["a-b"] = second
top = Pattern()
top.ref("a b")
top.ref("a-b", offset=(5, 0))
lib["top"] = top
svg_path = tmp_path / "colliding_ids.svg"
svg.writefile(lib, "top", str(svg_path))
root = ET.fromstring(svg_path.read_text())
ids = [group.attrib["id"] for group in root.iter(f"{SVG_NS}g")]
hrefs = [use.attrib[XLINK_HREF] for use in root.iter(f"{SVG_NS}use")]
assert ids.count("a_b") == 1
assert len(set(ids)) == len(ids)
assert "#a_b" in hrefs
assert "#a_b_2" in hrefs

View file

@ -1,8 +1,14 @@
from pathlib import Path
import numpy
from numpy.testing import assert_equal, assert_allclose
from numpy import pi
import pytest
from ..utils import remove_duplicate_vertices, remove_colinear_vertices, poly_contains_points, rotation_matrix_2d, apply_transforms, DeferredDict
from ..file.utils import tmpfile
from ..utils.curves import bezier
from ..error import PatternError
def test_remove_duplicate_vertices() -> None:
@ -88,6 +94,19 @@ def test_apply_transforms_advanced() -> None:
assert_allclose(combined[0], [0, 10, pi / 2, 1, 1], atol=1e-10)
def test_bezier_validates_weight_length() -> None:
with pytest.raises(PatternError, match='one entry per control point'):
bezier([[0, 0], [1, 1]], [0, 0.5, 1], weights=[1])
with pytest.raises(PatternError, match='one entry per control point'):
bezier([[0, 0], [1, 1]], [0, 0.5, 1], weights=[1, 2, 3])
def test_bezier_accepts_exact_weight_count() -> None:
samples = bezier([[0, 0], [1, 1]], [0, 0.5, 1], weights=[1, 2])
assert_allclose(samples, [[0, 0], [2 / 3, 2 / 3], [1, 1]], atol=1e-10)
def test_deferred_dict_accessors_resolve_values_once() -> None:
calls = 0
@ -104,3 +123,48 @@ def test_deferred_dict_accessors_resolve_values_once() -> None:
assert list(deferred.values()) == [7]
assert list(deferred.items()) == [("x", 7)]
assert calls == 1
def test_deferred_dict_mutating_accessors_preserve_value_semantics() -> None:
calls = 0
def make_value() -> int:
nonlocal calls
calls += 1
return 7
deferred = DeferredDict[str, int]()
assert deferred.setdefault("x", 5) == 5
assert deferred["x"] == 5
assert deferred.setdefault("y", make_value) == 7
assert deferred["y"] == 7
assert calls == 1
copy_deferred = deferred.copy()
assert isinstance(copy_deferred, DeferredDict)
assert copy_deferred["x"] == 5
assert copy_deferred["y"] == 7
assert calls == 1
assert deferred.pop("x") == 5
assert deferred.pop("missing", 9) == 9
assert deferred.popitem() == ("y", 7)
def test_tmpfile_cleans_up_on_exception(tmp_path: Path) -> None:
target = tmp_path / "out.txt"
temp_path = None
try:
with tmpfile(target) as stream:
temp_path = Path(stream.name)
stream.write(b"hello")
raise RuntimeError("boom")
except RuntimeError:
pass
assert temp_path is not None
assert not target.exists()
assert not temp_path.exists()

View file

@ -2,6 +2,8 @@ import numpy
from numpy.typing import ArrayLike, NDArray
from numpy import pi
from ..error import PatternError
try:
from numpy import trapezoid
except ImportError:
@ -31,6 +33,11 @@ def bezier(
tt = numpy.asarray(tt)
nn = nodes.shape[0]
weights = numpy.ones(nn) if weights is None else numpy.asarray(weights)
if weights.ndim != 1 or weights.shape[0] != nn:
raise PatternError(
f'weights must be a 1D array with one entry per control point; '
f'got shape {weights.shape} for {nn} control points'
)
with numpy.errstate(divide='ignore'):
umul = (tt / (1 - tt)).clip(max=1)

View file

@ -5,6 +5,7 @@ from functools import lru_cache
Key = TypeVar('Key')
Value = TypeVar('Value')
_MISSING = object()
class DeferredDict(dict, Generic[Key, Value]):
@ -46,6 +47,15 @@ class DeferredDict(dict, Generic[Key, Value]):
return default
return self[key]
def setdefault(self, key: Key, default: Value | Callable[[], Value] | None = None) -> Value | None:
    """Return the value for `key`, inserting `default` first if the key is absent.

    A callable `default` is stored as-is so it can be resolved lazily on
    first access, matching the dict's deferred-value semantics; any other
    value is stored via `set_const`.  The returned value is always resolved.
    """
    if key in self:
        # Existing entry: lookup through __getitem__ resolves any deferred value.
        return self[key]
    if callable(default):
        # Store the callable itself; it is evaluated on the lookup below.
        self[key] = default
    else:
        # NOTE(review): presumably set_const wraps the value so __getitem__
        # can treat all raw entries uniformly -- set_const not visible here.
        self.set_const(key, default)
    return self[key]
def items(self) -> Iterator[tuple[Key, Value]]:
for key in self.keys():
yield key, self[key]
@ -65,6 +75,25 @@ class DeferredDict(dict, Generic[Key, Value]):
else:
self.set_const(k, v)
def pop(self, key: Key, default: Value | object = _MISSING) -> Value:
    """Remove `key` and return its resolved value.

    Mirrors `dict.pop()`: if the key is missing, return `default` when one
    was supplied, otherwise raise KeyError.  `_MISSING` is a module-level
    sentinel so that `None` remains a valid explicit default.
    """
    if key in self:
        # Resolve through __getitem__ first so the caller receives the
        # value rather than the stored thunk, then drop the raw entry.
        value = self[key]
        dict.pop(self, key)
        return value
    if default is _MISSING:
        raise KeyError(key)
    return default  # type: ignore[return-value]
def popitem(self) -> tuple[Key, Value]:
    """Remove and return an arbitrary `(key, resolved_value)` pair.

    The raw stored entry is assumed to be a zero-argument callable (deferred
    value), so it is invoked here to produce the actual value -- NOTE(review):
    this relies on set_const also storing callables; confirm against the class.
    """
    key, value_func = dict.popitem(self)
    return key, value_func()
def copy(self) -> 'DeferredDict[Key, Value]':
    """Return a shallow copy sharing the stored (possibly unevaluated) entries.

    Raw entries are transferred with dict.__setitem__/dict.__getitem__ so the
    copy does not trigger evaluation of any deferred values.
    """
    new = DeferredDict[Key, Value]()
    for key in self.keys():
        dict.__setitem__(new, key, dict.__getitem__(self, key))
    return new
def __repr__(self) -> str:
return '<DeferredDict with keys ' + repr(set(self.keys())) + '>'

View file

@ -122,6 +122,7 @@ def data_to_ports(
if not found_ports:
return pattern
imported_ports: dict[str, Port] = {}
for target, refs in pattern.refs.items():
if target is None:
continue
@ -133,9 +134,14 @@ def data_to_ports(
if not aa.ports:
break
if ref.repetition is not None:
logger.warning(f'Pattern {name if name else pattern} has repeated ref to {target!r}; '
'data_to_ports() is importing only the base instance ports')
aa.apply_ref_transform(ref)
pattern.check_ports(other_names=aa.ports.keys())
pattern.ports.update(aa.ports)
Pattern(ports={**pattern.ports, **imported_ports}).check_ports(other_names=aa.ports.keys())
imported_ports.update(aa.ports)
pattern.ports.update(imported_ports)
return pattern
@ -178,7 +184,14 @@ def data_to_ports_flat(
name, property_string = label.string.split(':', 1)
properties = property_string.split()
ptype = properties[0] if len(properties) > 0 else 'unk'
angle_deg = float(properties[1]) if len(properties) > 1 else numpy.inf
if len(properties) > 1:
try:
angle_deg = float(properties[1])
except ValueError:
logger.warning(f'Invalid port label "{label.string}" in pattern "{pstr}" (bad angle)')
continue
else:
angle_deg = numpy.inf
xy = label.offset
angle = numpy.deg2rad(angle_deg) if numpy.isfinite(angle_deg) else None
@ -190,4 +203,3 @@ def data_to_ports_flat(
pattern.ports.update(local_ports)
return pattern