diff --git a/masque/file/gdsii.py b/masque/file/gdsii.py index f589ad8..116fa07 100644 --- a/masque/file/gdsii.py +++ b/masque/file/gdsii.py @@ -637,6 +637,7 @@ def check_valid_names( max_length: Max allowed length """ + names = tuple(names) allowed_chars = set(string.ascii_letters + string.digits + '_?$') bad_chars = [ diff --git a/masque/file/svg.py b/masque/file/svg.py index f235b50..621bcdb 100644 --- a/masque/file/svg.py +++ b/masque/file/svg.py @@ -30,6 +30,21 @@ def _ref_to_svg_transform(ref) -> str: return f'matrix({a:g} {b:g} {c:g} {d:g} {e:g} {f:g})' +def _make_svg_ids(names: Mapping[str, Pattern]) -> dict[str, str]: + svg_ids: dict[str, str] = {} + seen_ids: set[str] = set() + for name in names: + base_id = mangle_name(name) + svg_id = base_id + suffix = 1 + while svg_id in seen_ids: + suffix += 1 + svg_id = f'{base_id}_{suffix}' + seen_ids.add(svg_id) + svg_ids[name] = svg_id + return svg_ids + + def writefile( library: Mapping[str, Pattern], top: str, @@ -81,10 +96,11 @@ def writefile( # Create file svg = svgwrite.Drawing(filename, profile='full', viewBox=viewbox_string, debug=(not custom_attributes)) + svg_ids = _make_svg_ids(library) # Now create a group for each pattern and add in any Boundary and Use elements for name, pat in library.items(): - svg_group = svg.g(id=mangle_name(name), fill='blue', stroke='red') + svg_group = svg.g(id=svg_ids[name], fill='blue', stroke='red') for layer, shapes in pat.shapes.items(): for shape in shapes: @@ -123,11 +139,11 @@ def writefile( continue for ref in refs: transform = _ref_to_svg_transform(ref) - use = svg.use(href='#' + mangle_name(target), transform=transform) + use = svg.use(href='#' + svg_ids[target], transform=transform) svg_group.add(use) svg.defs.add(svg_group) - svg.add(svg.use(href='#' + mangle_name(top))) + svg.add(svg.use(href='#' + svg_ids[top])) svg.save() diff --git a/masque/file/utils.py b/masque/file/utils.py index 25bc61d..58c7573 100644 --- a/masque/file/utils.py +++ b/masque/file/utils.py @@ -33,6 +33,12 @@ def preflight( Run a standard set of useful operations and checks, usually done immediately prior to writing to a file (or immediately after reading). + Note that this helper is not copy-isolating. When `sort=True`, it constructs a new + `Library` wrapper around the same `Pattern` objects after sorting them in place, so + later mutating preflight steps such as `prune_empty_patterns` and + `wrap_repeated_shapes` may still mutate caller-owned patterns. Callers that need + isolation should deep-copy the library before calling `preflight()`. + Args: sort: Whether to sort the patterns based on their names, and optionaly sort the pattern contents. Default True. Useful for reproducible builds. @@ -145,7 +151,11 @@ def tmpfile(path: str | pathlib.Path) -> Iterator[IO[bytes]]: path = pathlib.Path(path) suffixes = ''.join(path.suffixes) with tempfile.NamedTemporaryFile(suffix=suffixes, delete=False) as tmp_stream: - yield tmp_stream + try: + yield tmp_stream + except Exception: + pathlib.Path(tmp_stream.name).unlink(missing_ok=True) + raise try: shutil.move(tmp_stream.name, path) diff --git a/masque/library.py b/masque/library.py index 9d1f1b7..825dbf0 100644 --- a/masque/library.py +++ b/masque/library.py @@ -22,7 +22,7 @@ import copy from pprint import pformat from collections import defaultdict from abc import ABCMeta, abstractmethod -from graphlib import TopologicalSorter +from graphlib import TopologicalSorter, CycleError import numpy from numpy.typing import ArrayLike, NDArray @@ -538,6 +538,7 @@ class ILibraryView(Mapping[str, 'Pattern'], metaclass=ABCMeta): raise LibraryError('visit_* functions returned a new `Pattern` object' ' but no top-level name was provided in `hierarchy`') + del cast('ILibrary', self)[name] cast('ILibrary', self)[name] = pattern return self @@ -617,7 +618,13 @@ class ILibraryView(Mapping[str, 'Pattern'], metaclass=ABCMeta): Return: Topologically sorted list of pattern names. """ - return cast('list[str]', list(TopologicalSorter(self.child_graph(dangling=dangling)).static_order())) + try: + return cast('list[str]', list(TopologicalSorter(self.child_graph(dangling=dangling)).static_order())) + except CycleError as exc: + cycle = exc.args[1] if len(exc.args) > 1 else None + if cycle is None: + raise LibraryError('Cycle found while building child order') from exc + raise LibraryError(f'Cycle found while building child order: {cycle}') from exc def find_refs_local( self, @@ -916,8 +923,8 @@ class ILibrary(ILibraryView, MutableMapping[str, 'Pattern'], metaclass=ABCMeta): (default). Returns: - A mapping of `{old_name: new_name}` for all `old_name`s in `other`. Unchanged - names map to themselves. + A mapping of `{old_name: new_name}` for all names in `other` which were + renamed while being added. Unchanged names are omitted. Raises: `LibraryError` if a duplicate name is encountered even after applying `rename_theirs()`. @@ -926,8 +933,13 @@ class ILibrary(ILibraryView, MutableMapping[str, 'Pattern'], metaclass=ABCMeta): duplicates = set(self.keys()) & set(other.keys()) if not duplicates: - for key in other: - self._merge(key, other, key) + if mutate_other: + temp = other + else: + temp = Library(copy.deepcopy(dict(other))) + + for key in temp: + self._merge(key, temp, key) return {} if mutate_other: diff --git a/masque/ports.py b/masque/ports.py index 8d7b6a6..3a695fb 100644 --- a/masque/ports.py +++ b/masque/ports.py @@ -510,10 +510,19 @@ class PortList(metaclass=ABCMeta): if missing_invals: raise PortError(f'`map_in` values not present in other device: {missing_invals}') + map_in_counts = Counter(map_in.values()) + conflicts_in = {kk for kk, vv in map_in_counts.items() if vv > 1} + if conflicts_in: + raise PortError(f'Duplicate values in `map_in`: {conflicts_in}') + missing_outkeys = set(map_out.keys()) - other if missing_outkeys: raise PortError(f'`map_out` keys not present in other device: {missing_outkeys}') + connected_outkeys = set(map_out.keys()) & set(map_in.values()) + if connected_outkeys: + raise PortError(f'`map_out` keys conflict with connected ports: {connected_outkeys}') + orig_remaining = set(self.ports.keys()) - set(map_in.keys()) other_remaining = other - set(map_out.keys()) - set(map_in.values()) mapped_vals = set(map_out.values()) diff --git a/masque/ref.py b/masque/ref.py index b012365..268d8d4 100644 --- a/masque/ref.py +++ b/masque/ref.py @@ -236,7 +236,10 @@ class Ref( bounds = numpy.vstack((numpy.min(corners, axis=0), numpy.max(corners, axis=0))) * self.scale + [self.offset] return bounds - return self.as_pattern(pattern=pattern).get_bounds(library) + + single_ref = self.deepcopy() + single_ref.repetition = None + return single_ref.as_pattern(pattern=pattern).get_bounds(library) def __repr__(self) -> str: rotation = f' r{numpy.rad2deg(self.rotation):g}' if self.rotation != 0 else '' diff --git a/masque/repetition.py b/masque/repetition.py index 99e1082..1e12fcd 100644 --- a/masque/repetition.py +++ b/masque/repetition.py @@ -184,6 +184,8 @@ class Grid(Repetition): def a_count(self, val: int) -> None: if val != int(val): raise PatternError('a_count must be convertable to an int!') + if int(val) < 1: + raise PatternError(f'Repetition has too-small a_count: {val}') self._a_count = int(val) # b_count property @@ -195,6 +197,8 @@ class Grid(Repetition): def b_count(self, val: int) -> None: if val != int(val): raise PatternError('b_count must be convertable to an int!') + if int(val) < 1: + raise PatternError(f'Repetition has too-small b_count: {val}') self._b_count = int(val) @property @@ -325,7 +329,22 @@ class Arbitrary(Repetition): @displacements.setter def displacements(self, val: ArrayLike) -> None: - vala = numpy.array(val, dtype=float) + try: + vala = numpy.array(val, dtype=float) + except (TypeError, ValueError) as exc: + raise PatternError('displacements must be convertible to an Nx2 ndarray') from exc + + if vala.size == 0: + self._displacements = numpy.empty((0, 2), dtype=float) + return + + if vala.ndim == 1: + if vala.size != 2: + raise PatternError('displacements must be convertible to an Nx2 ndarray') + vala = vala.reshape(1, 2) + elif vala.ndim != 2 or vala.shape[1] != 2: + raise PatternError('displacements must be convertible to an Nx2 ndarray') + order = numpy.lexsort(vala.T[::-1]) # sortrows self._displacements = vala[order] diff --git a/masque/shapes/poly_collection.py b/masque/shapes/poly_collection.py index cd233ba..6c23da7 100644 --- a/masque/shapes/poly_collection.py +++ b/masque/shapes/poly_collection.py @@ -219,7 +219,7 @@ class PolyCollection(Shape): (offset, scale / norm_value, rotation, False), lambda: PolyCollection( vertex_lists=rotated_vertices * norm_value, - vertex_offsets=self._vertex_offsets, + vertex_offsets=self._vertex_offsets.copy(), ), ) diff --git a/masque/test/test_gdsii.py b/masque/test/test_gdsii.py index 7ce8c88..7a2f5b1 100644 --- a/masque/test/test_gdsii.py +++ b/masque/test/test_gdsii.py @@ -1,8 +1,10 @@ from pathlib import Path from typing import cast import numpy +import pytest from numpy.testing import assert_equal, assert_allclose +from ..error import LibraryError from ..pattern import Pattern from ..library import Library from ..file import gdsii @@ -69,3 +71,10 @@ def test_gdsii_annotations(tmp_path: Path) -> None: read_ann = read_lib["cell"].shapes[(1, 0)][0].annotations assert read_ann is not None assert read_ann["1"] == ["hello"] + + +def test_gdsii_check_valid_names_validates_generator_lengths() -> None: + names = (name for name in ("a" * 40,)) + + with pytest.raises(LibraryError, match="invalid names"): + gdsii.check_valid_names(names) diff --git a/masque/test/test_library.py b/masque/test/test_library.py index 6ac8536..3b731ad 100644 --- a/masque/test/test_library.py +++ b/masque/test/test_library.py @@ -221,6 +221,77 @@ def test_library_rename() -> None: assert "old" not in lib["parent"].refs +def test_library_dfs_can_replace_existing_patterns() -> None: + lib = Library() + child = Pattern() + lib["child"] = child + top = Pattern() + top.ref("child") + lib["top"] = top + + replacement_top = Pattern(ports={"T": Port((1, 2), 0)}) + replacement_child = Pattern(ports={"C": Port((3, 4), 0)}) + + def visit_after(pattern: Pattern, hierarchy: tuple[str | None, ...], **kwargs) -> Pattern: # noqa: ARG001 + if hierarchy[-1] == "child": + return replacement_child + if hierarchy[-1] == "top": + return replacement_top + return pattern + + lib.dfs(lib["top"], visit_after=visit_after, hierarchy=("top",), transform=True) + + assert lib["top"] is replacement_top + assert lib["child"] is replacement_child + + +def test_lazy_library_dfs_can_replace_existing_patterns() -> None: + lib = LazyLibrary() + lib["child"] = lambda: Pattern() + lib["top"] = lambda: Pattern(refs={"child": []}) + + top = lib["top"] + top.ref("child") + + replacement_top = Pattern(ports={"T": Port((1, 2), 0)}) + replacement_child = Pattern(ports={"C": Port((3, 4), 0)}) + + def visit_after(pattern: Pattern, hierarchy: tuple[str | None, ...], **kwargs) -> Pattern: # noqa: ARG001 + if hierarchy[-1] == "child": + return replacement_child + if hierarchy[-1] == "top": + return replacement_top + return pattern + + lib.dfs(top, visit_after=visit_after, hierarchy=("top",), transform=True) + + assert lib["top"] is replacement_top + assert lib["child"] is replacement_child + + +def test_library_add_no_duplicates_respects_mutate_other_false() -> None: + src_pat = Pattern(ports={"A": Port((0, 0), 0)}) + lib = Library({"a": Pattern()}) + + lib.add({"b": src_pat}, mutate_other=False) + + assert lib["b"] is not src_pat + lib["b"].ports["A"].offset[0] = 123 + assert tuple(src_pat.ports["A"].offset) == (0.0, 0.0) + + +def test_library_add_returns_only_renamed_entries() -> None: + lib = Library({"a": Pattern(), "_shape": Pattern()}) + + assert lib.add({"b": Pattern(), "c": Pattern()}, mutate_other=False) == {} + + rename_map = lib.add({"_shape": Pattern(), "keep": Pattern()}, mutate_other=False) + + assert set(rename_map) == {"_shape"} + assert rename_map["_shape"] != "_shape" + assert "keep" not in rename_map + + def test_library_subtree() -> None: lib = Library() lib["a"] = Pattern() @@ -234,6 +305,26 @@ def test_library_subtree() -> None: assert "c" not in sub +def test_library_child_order_cycle_raises_library_error() -> None: + lib = Library() + lib["a"] = Pattern() + lib["a"].ref("b") + lib["b"] = Pattern() + lib["b"].ref("a") + + with pytest.raises(LibraryError, match="Cycle found while building child order"): + lib.child_order() + + +def test_library_find_refs_global_cycle_raises_library_error() -> None: + lib = Library() + lib["a"] = Pattern() + lib["a"].ref("a") + + with pytest.raises(LibraryError, match="Cycle found while building child order"): + lib.find_refs_global("a") + + def test_library_get_name() -> None: lib = Library() lib["cell"] = Pattern() diff --git a/masque/test/test_ports.py b/masque/test/test_ports.py index 0f60809..4e7d097 100644 --- a/masque/test/test_ports.py +++ b/masque/test/test_ports.py @@ -4,6 +4,7 @@ from numpy import pi from ..ports import Port, PortList from ..error import PortError +from ..pattern import Pattern def test_port_init() -> None: @@ -227,3 +228,32 @@ def test_port_list_plugged_mismatch() -> None: pl = MyPorts() with pytest.raises(PortError): pl.plugged({"A": "B"}) + + +def test_port_list_check_ports_duplicate_map_in_values_raise() -> None: + class MyPorts(PortList): + def __init__(self) -> None: + self._ports = {"A": Port((0, 0), 0), "B": Port((0, 0), 0)} + + @property + def ports(self) -> dict[str, Port]: + return self._ports + + @ports.setter + def ports(self, val: dict[str, Port]) -> None: + self._ports = val + + pl = MyPorts() + with pytest.raises(PortError, match="Duplicate values in `map_in`"): + pl.check_ports({"X", "Y"}, map_in={"A": "X", "B": "X"}) + assert set(pl.ports) == {"A", "B"} + + +def test_pattern_plug_rejects_map_out_on_connected_ports_atomically() -> None: + host = Pattern(ports={"A": Port((0, 0), 0)}) + other = Pattern(ports={"X": Port((0, 0), pi), "Y": Port((5, 0), 0)}) + + with pytest.raises(PortError, match="`map_out` keys conflict with connected ports"): + host.plug(other, {"A": "X"}, map_out={"X": "renamed", "Y": "out"}, append=True) + + assert set(host.ports) == {"A"} diff --git a/masque/test/test_ports2data.py b/masque/test/test_ports2data.py index 72f6870..3f642ab 100644 --- a/masque/test/test_ports2data.py +++ b/masque/test/test_ports2data.py @@ -1,10 +1,13 @@ import numpy +import pytest from numpy.testing import assert_allclose from ..utils.ports2data import ports_to_data, data_to_ports from ..pattern import Pattern from ..ports import Port from ..library import Library +from ..error import PortError +from ..repetition import Grid def test_ports2data_roundtrip() -> None: @@ -74,3 +77,56 @@ def test_data_to_ports_hierarchical_scaled_ref() -> None: assert_allclose(parent.ports["A"].offset, [100, 110], atol=1e-10) assert parent.ports["A"].rotation is not None assert_allclose(parent.ports["A"].rotation, numpy.pi / 2, atol=1e-10) + + +def test_data_to_ports_hierarchical_repeated_ref_warns_and_keeps_best_effort( + caplog: pytest.LogCaptureFixture, + ) -> None: + lib = Library() + + child = Pattern() + layer = (10, 0) + child.label(layer=layer, string="A:type1 0", offset=(5, 0)) + lib["child"] = child + + parent = Pattern() + parent.ref("child", repetition=Grid(a_vector=(100, 0), a_count=3)) + + caplog.set_level("WARNING") + data_to_ports([layer], lib, parent, max_depth=1) + + assert "A" in parent.ports + assert_allclose(parent.ports["A"].offset, [5, 0], atol=1e-10) + assert any("importing only the base instance ports" in record.message for record in caplog.records) + + +def test_data_to_ports_hierarchical_collision_is_atomic() -> None: + lib = Library() + + child = Pattern() + layer = (10, 0) + child.label(layer=layer, string="A:type1 0", offset=(5, 0)) + lib["child"] = child + + parent = Pattern() + parent.ref("child", offset=(0, 0)) + parent.ref("child", offset=(10, 0)) + + with pytest.raises(PortError, match="Device ports conflict with existing ports"): + data_to_ports([layer], lib, parent, max_depth=1) + + assert not parent.ports + + +def test_data_to_ports_flat_bad_angle_warns_and_skips( + caplog: pytest.LogCaptureFixture, + ) -> None: + layer = (10, 0) + pat = Pattern() + pat.label(layer=layer, string="A:type1 nope", offset=(5, 0)) + + caplog.set_level("WARNING") + data_to_ports([layer], {}, pat) + + assert not pat.ports + assert any('bad angle' in record.message for record in caplog.records) diff --git a/masque/test/test_ref.py b/masque/test/test_ref.py index d3e9778..de330fa 100644 --- a/masque/test/test_ref.py +++ b/masque/test/test_ref.py @@ -64,6 +64,22 @@ def test_ref_get_bounds() -> None: assert_equal(bounds, [[10, 10], [20, 20]]) +def test_ref_get_bounds_single_ignores_repetition_for_non_manhattan_rotation() -> None: + sub_pat = Pattern() + sub_pat.rect((1, 0), xmin=0, xmax=1, ymin=0, ymax=2) + + rep = Grid(a_vector=(5, 0), b_vector=(0, 7), a_count=3, b_count=2) + ref = Ref(offset=(10, 20), rotation=pi / 4, repetition=rep) + + bounds = ref.get_bounds_single(sub_pat) + repeated_bounds = ref.get_bounds(sub_pat) + + assert bounds is not None + assert repeated_bounds is not None + assert repeated_bounds[1, 0] > bounds[1, 0] + assert repeated_bounds[1, 1] > bounds[1, 1] + + def test_ref_copy() -> None: ref1 = Ref(offset=(1, 2), rotation=0.5, annotations={"a": [1]}) ref2 = ref1.copy() diff --git a/masque/test/test_repetition.py b/masque/test/test_repetition.py index f423ab2..0d0be41 100644 --- a/masque/test/test_repetition.py +++ b/masque/test/test_repetition.py @@ -1,7 +1,9 @@ +import pytest from numpy.testing import assert_equal, assert_allclose from numpy import pi from ..repetition import Grid, Arbitrary +from ..error import PatternError def test_grid_displacements() -> None: @@ -51,6 +53,30 @@ def test_arbitrary_transform() -> None: assert_allclose(arb.displacements, [[0, -10]], atol=1e-10) +def test_arbitrary_empty_repetition_is_allowed() -> None: + arb = Arbitrary([]) + assert arb.displacements.shape == (0, 2) + assert arb.get_bounds() is None + + +def test_arbitrary_rejects_non_nx2_displacements() -> None: + for displacements in ([[1], [2]], [[1, 2, 3]], [1, 2, 3]): + with pytest.raises(PatternError, match='displacements must be convertible to an Nx2 ndarray'): + Arbitrary(displacements) + + +def test_grid_count_setters_reject_nonpositive_values() -> None: + for attr, value, message in ( + ('a_count', 0, 'a_count'), + ('a_count', -1, 'a_count'), + ('b_count', 0, 'b_count'), + ('b_count', -1, 'b_count'), + ): + grid = Grid(a_vector=(10, 0), b_vector=(0, 5), a_count=2, b_count=2) + with pytest.raises(PatternError, match=message): + setattr(grid, attr, value) + + def test_repetition_less_equal_includes_equality() -> None: grid_a = Grid(a_vector=(10, 0), a_count=2) grid_b = Grid(a_vector=(10, 0), a_count=2) diff --git a/masque/test/test_shape_advanced.py b/masque/test/test_shape_advanced.py index 2dec264..350e8f0 100644 --- a/masque/test/test_shape_advanced.py +++ b/masque/test/test_shape_advanced.py @@ -212,3 +212,27 @@ def test_poly_collection_valid() -> None: assert len(sorted_shapes) == 4 # Just verify it doesn't crash and is stable assert sorted(sorted_shapes) == sorted_shapes + + +def test_poly_collection_normalized_form_reconstruction_is_independent() -> None: + pc = PolyCollection([[0, 0], [1, 0], [0, 1]], [0]) + _intrinsic, _extrinsic, rebuild = pc.normalized_form(1) + + clone = rebuild() + clone.vertex_offsets[:] = [5] + + assert_equal(pc.vertex_offsets, [0]) + assert_equal(clone.vertex_offsets, [5]) + + +def test_poly_collection_normalized_form_rebuilds_independent_clones() -> None: + pc = PolyCollection([[0, 0], [1, 0], [0, 1]], [0]) + _intrinsic, _extrinsic, rebuild = pc.normalized_form(1) + + first = rebuild() + second = rebuild() + first.vertex_offsets[:] = [7] + + assert_equal(first.vertex_offsets, [7]) + assert_equal(second.vertex_offsets, [0]) + assert_equal(pc.vertex_offsets, [0]) diff --git a/masque/test/test_svg.py b/masque/test/test_svg.py index a3261b6..b637853 100644 --- a/masque/test/test_svg.py +++ b/masque/test/test_svg.py @@ -68,3 +68,31 @@ def test_svg_ref_mirroring_changes_affine_transform(tmp_path: Path) -> None: assert_allclose(plain_transform, (0, 2, -2, 0, 3, 4), atol=1e-10) assert_allclose(mirrored_transform, (0, 2, 2, 0, 3, 4), atol=1e-10) + + +def test_svg_uses_unique_ids_for_colliding_mangled_names(tmp_path: Path) -> None: + lib = Library() + first = Pattern() + first.polygon("1", vertices=[[0, 0], [1, 0], [0, 1]]) + lib["a b"] = first + + second = Pattern() + second.polygon("1", vertices=[[0, 0], [2, 0], [0, 2]]) + lib["a-b"] = second + + top = Pattern() + top.ref("a b") + top.ref("a-b", offset=(5, 0)) + lib["top"] = top + + svg_path = tmp_path / "colliding_ids.svg" + svg.writefile(lib, "top", str(svg_path)) + + root = ET.fromstring(svg_path.read_text()) + ids = [group.attrib["id"] for group in root.iter(f"{SVG_NS}g")] + hrefs = [use.attrib[XLINK_HREF] for use in root.iter(f"{SVG_NS}use")] + + assert ids.count("a_b") == 1 + assert len(set(ids)) == len(ids) + assert "#a_b" in hrefs + assert "#a_b_2" in hrefs diff --git a/masque/test/test_utils.py b/masque/test/test_utils.py index 45e347e..0511a24 100644 --- a/masque/test/test_utils.py +++ b/masque/test/test_utils.py @@ -1,8 +1,14 @@ +from pathlib import Path + import numpy from numpy.testing import assert_equal, assert_allclose from numpy import pi +import pytest from ..utils import remove_duplicate_vertices, remove_colinear_vertices, poly_contains_points, rotation_matrix_2d, apply_transforms, DeferredDict +from ..file.utils import tmpfile +from ..utils.curves import bezier +from ..error import PatternError def test_remove_duplicate_vertices() -> None: @@ -88,6 +94,19 @@ def test_apply_transforms_advanced() -> None: assert_allclose(combined[0], [0, 10, pi / 2, 1, 1], atol=1e-10) +def test_bezier_validates_weight_length() -> None: + with pytest.raises(PatternError, match='one entry per control point'): + bezier([[0, 0], [1, 1]], [0, 0.5, 1], weights=[1]) + + with pytest.raises(PatternError, match='one entry per control point'): + bezier([[0, 0], [1, 1]], [0, 0.5, 1], weights=[1, 2, 3]) + + +def test_bezier_accepts_exact_weight_count() -> None: + samples = bezier([[0, 0], [1, 1]], [0, 0.5, 1], weights=[1, 2]) + assert_allclose(samples, [[0, 0], [2 / 3, 2 / 3], [1, 1]], atol=1e-10) + + def test_deferred_dict_accessors_resolve_values_once() -> None: calls = 0 @@ -104,3 +123,48 @@ def test_deferred_dict_accessors_resolve_values_once() -> None: assert list(deferred.values()) == [7] assert list(deferred.items()) == [("x", 7)] assert calls == 1 + + +def test_deferred_dict_mutating_accessors_preserve_value_semantics() -> None: + calls = 0 + + def make_value() -> int: + nonlocal calls + calls += 1 + return 7 + + deferred = DeferredDict[str, int]() + + assert deferred.setdefault("x", 5) == 5 + assert deferred["x"] == 5 + + assert deferred.setdefault("y", make_value) == 7 + assert deferred["y"] == 7 + assert calls == 1 + + copy_deferred = deferred.copy() + assert isinstance(copy_deferred, DeferredDict) + assert copy_deferred["x"] == 5 + assert copy_deferred["y"] == 7 + assert calls == 1 + + assert deferred.pop("x") == 5 + assert deferred.pop("missing", 9) == 9 + assert deferred.popitem() == ("y", 7) + + +def test_tmpfile_cleans_up_on_exception(tmp_path: Path) -> None: + target = tmp_path / "out.txt" + temp_path = None + + try: + with tmpfile(target) as stream: + temp_path = Path(stream.name) + stream.write(b"hello") + raise RuntimeError("boom") + except RuntimeError: + pass + + assert temp_path is not None + assert not target.exists() + assert not temp_path.exists() diff --git a/masque/utils/curves.py b/masque/utils/curves.py index 2348678..3a7671b 100644 --- a/masque/utils/curves.py +++ b/masque/utils/curves.py @@ -2,6 +2,8 @@ import numpy from numpy.typing import ArrayLike, NDArray from numpy import pi +from ..error import PatternError + try: from numpy import trapezoid except ImportError: @@ -31,6 +33,11 @@ def bezier( tt = numpy.asarray(tt) nn = nodes.shape[0] weights = numpy.ones(nn) if weights is None else numpy.asarray(weights) + if weights.ndim != 1 or weights.shape[0] != nn: + raise PatternError( + f'weights must be a 1D array with one entry per control point; ' + f'got shape {weights.shape} for {nn} control points' + ) with numpy.errstate(divide='ignore'): umul = (tt / (1 - tt)).clip(max=1) diff --git a/masque/utils/deferreddict.py b/masque/utils/deferreddict.py index def9b10..70893c0 100644 --- a/masque/utils/deferreddict.py +++ b/masque/utils/deferreddict.py @@ -5,6 +5,7 @@ from functools import lru_cache Key = TypeVar('Key') Value = TypeVar('Value') +_MISSING = object() class DeferredDict(dict, Generic[Key, Value]): @@ -46,6 +47,15 @@ class DeferredDict(dict, Generic[Key, Value]): return default return self[key] + def setdefault(self, key: Key, default: Value | Callable[[], Value] | None = None) -> Value | None: + if key in self: + return self[key] + if callable(default): + self[key] = default + else: + self.set_const(key, default) + return self[key] + def items(self) -> Iterator[tuple[Key, Value]]: for key in self.keys(): yield key, self[key] @@ -65,6 +75,25 @@ class DeferredDict(dict, Generic[Key, Value]): else: self.set_const(k, v) + def pop(self, key: Key, default: Value | object = _MISSING) -> Value: + if key in self: + value = self[key] + dict.pop(self, key) + return value + if default is _MISSING: + raise KeyError(key) + return default # type: ignore[return-value] + + def popitem(self) -> tuple[Key, Value]: + key, value_func = dict.popitem(self) + return key, value_func() + + def copy(self) -> 'DeferredDict[Key, Value]': + new = DeferredDict[Key, Value]() + for key in self.keys(): + dict.__setitem__(new, key, dict.__getitem__(self, key)) + return new + def __repr__(self) -> str: return '' diff --git a/masque/utils/ports2data.py b/masque/utils/ports2data.py index c7f42e1..44a0ec3 100644 --- a/masque/utils/ports2data.py +++ b/masque/utils/ports2data.py @@ -122,6 +122,7 @@ def data_to_ports( if not found_ports: return pattern + imported_ports: dict[str, Port] = {} for target, refs in pattern.refs.items(): if target is None: continue @@ -133,9 +134,14 @@ def data_to_ports( if not aa.ports: break + if ref.repetition is not None: + logger.warning(f'Pattern {name if name else pattern} has repeated ref to {target!r}; ' + 'data_to_ports() is importing only the base instance ports') aa.apply_ref_transform(ref) - pattern.check_ports(other_names=aa.ports.keys()) - pattern.ports.update(aa.ports) + Pattern(ports={**pattern.ports, **imported_ports}).check_ports(other_names=aa.ports.keys()) + imported_ports.update(aa.ports) + + pattern.ports.update(imported_ports) return pattern @@ -178,7 +184,14 @@ def data_to_ports_flat( name, property_string = label.string.split(':', 1) properties = property_string.split() ptype = properties[0] if len(properties) > 0 else 'unk' - angle_deg = float(properties[1]) if len(properties) > 1 else numpy.inf + if len(properties) > 1: + try: + angle_deg = float(properties[1]) + except ValueError: + logger.warning(f'Invalid port label "{label.string}" in pattern "{pstr}" (bad angle)') + continue + else: + angle_deg = numpy.inf xy = label.offset angle = numpy.deg2rad(angle_deg) if numpy.isfinite(angle_deg) else None @@ -190,4 +203,3 @@ def data_to_ports_flat( pattern.ports.update(local_ports) return pattern -