diff --git a/masque/file/gdsii.py b/masque/file/gdsii.py index 116fa07..f589ad8 100644 --- a/masque/file/gdsii.py +++ b/masque/file/gdsii.py @@ -637,7 +637,6 @@ def check_valid_names( max_length: Max allowed length """ - names = tuple(names) allowed_chars = set(string.ascii_letters + string.digits + '_?$') bad_chars = [ diff --git a/masque/file/svg.py b/masque/file/svg.py index 621bcdb..f235b50 100644 --- a/masque/file/svg.py +++ b/masque/file/svg.py @@ -30,21 +30,6 @@ def _ref_to_svg_transform(ref) -> str: return f'matrix({a:g} {b:g} {c:g} {d:g} {e:g} {f:g})' -def _make_svg_ids(names: Mapping[str, Pattern]) -> dict[str, str]: - svg_ids: dict[str, str] = {} - seen_ids: set[str] = set() - for name in names: - base_id = mangle_name(name) - svg_id = base_id - suffix = 1 - while svg_id in seen_ids: - suffix += 1 - svg_id = f'{base_id}_{suffix}' - seen_ids.add(svg_id) - svg_ids[name] = svg_id - return svg_ids - - def writefile( library: Mapping[str, Pattern], top: str, @@ -96,11 +81,10 @@ def writefile( # Create file svg = svgwrite.Drawing(filename, profile='full', viewBox=viewbox_string, debug=(not custom_attributes)) - svg_ids = _make_svg_ids(library) # Now create a group for each pattern and add in any Boundary and Use elements for name, pat in library.items(): - svg_group = svg.g(id=svg_ids[name], fill='blue', stroke='red') + svg_group = svg.g(id=mangle_name(name), fill='blue', stroke='red') for layer, shapes in pat.shapes.items(): for shape in shapes: @@ -139,11 +123,11 @@ def writefile( continue for ref in refs: transform = _ref_to_svg_transform(ref) - use = svg.use(href='#' + svg_ids[target], transform=transform) + use = svg.use(href='#' + mangle_name(target), transform=transform) svg_group.add(use) svg.defs.add(svg_group) - svg.add(svg.use(href='#' + svg_ids[top])) + svg.add(svg.use(href='#' + mangle_name(top))) svg.save() diff --git a/masque/file/utils.py b/masque/file/utils.py index 58c7573..25bc61d 100644 --- a/masque/file/utils.py +++ b/masque/file/utils.py @@ -33,12 +33,6 @@ def preflight( Run a standard set of useful operations and checks, usually done immediately prior to writing to a file (or immediately after reading). - Note that this helper is not copy-isolating. When `sort=True`, it constructs a new - `Library` wrapper around the same `Pattern` objects after sorting them in place, so - later mutating preflight steps such as `prune_empty_patterns` and - `wrap_repeated_shapes` may still mutate caller-owned patterns. Callers that need - isolation should deep-copy the library before calling `preflight()`. - Args: sort: Whether to sort the patterns based on their names, and optionaly sort the pattern contents. Default True. Useful for reproducible builds. @@ -151,11 +145,7 @@ def tmpfile(path: str | pathlib.Path) -> Iterator[IO[bytes]]: path = pathlib.Path(path) suffixes = ''.join(path.suffixes) with tempfile.NamedTemporaryFile(suffix=suffixes, delete=False) as tmp_stream: - try: - yield tmp_stream - except Exception: - pathlib.Path(tmp_stream.name).unlink(missing_ok=True) - raise + yield tmp_stream try: shutil.move(tmp_stream.name, path) diff --git a/masque/library.py b/masque/library.py index 825dbf0..9d1f1b7 100644 --- a/masque/library.py +++ b/masque/library.py @@ -22,7 +22,7 @@ import copy from pprint import pformat from collections import defaultdict from abc import ABCMeta, abstractmethod -from graphlib import TopologicalSorter, CycleError +from graphlib import TopologicalSorter import numpy from numpy.typing import ArrayLike, NDArray @@ -538,7 +538,6 @@ class ILibraryView(Mapping[str, 'Pattern'], metaclass=ABCMeta): raise LibraryError('visit_* functions returned a new `Pattern` object' ' but no top-level name was provided in `hierarchy`') - del cast('ILibrary', self)[name] cast('ILibrary', self)[name] = pattern return self @@ -618,13 +617,7 @@ class ILibraryView(Mapping[str, 'Pattern'], metaclass=ABCMeta): Return: Topologically sorted list of pattern names. """ - try: - return cast('list[str]', list(TopologicalSorter(self.child_graph(dangling=dangling)).static_order())) - except CycleError as exc: - cycle = exc.args[1] if len(exc.args) > 1 else None - if cycle is None: - raise LibraryError('Cycle found while building child order') from exc - raise LibraryError(f'Cycle found while building child order: {cycle}') from exc + return cast('list[str]', list(TopologicalSorter(self.child_graph(dangling=dangling)).static_order())) def find_refs_local( self, @@ -923,8 +916,8 @@ class ILibrary(ILibraryView, MutableMapping[str, 'Pattern'], metaclass=ABCMeta): (default). Returns: - A mapping of `{old_name: new_name}` for all names in `other` which were - renamed while being added. Unchanged names are omitted. + A mapping of `{old_name: new_name}` for all `old_name`s in `other`. Unchanged + names map to themselves. Raises: `LibraryError` if a duplicate name is encountered even after applying `rename_theirs()`. @@ -933,13 +926,8 @@ class ILibrary(ILibraryView, MutableMapping[str, 'Pattern'], metaclass=ABCMeta): duplicates = set(self.keys()) & set(other.keys()) if not duplicates: - if mutate_other: - temp = other - else: - temp = Library(copy.deepcopy(dict(other))) - - for key in temp: - self._merge(key, temp, key) + for key in other: + self._merge(key, other, key) return {} if mutate_other: diff --git a/masque/ports.py b/masque/ports.py index 3a695fb..8d7b6a6 100644 --- a/masque/ports.py +++ b/masque/ports.py @@ -510,19 +510,10 @@ class PortList(metaclass=ABCMeta): if missing_invals: raise PortError(f'`map_in` values not present in other device: {missing_invals}') - map_in_counts = Counter(map_in.values()) - conflicts_in = {kk for kk, vv in map_in_counts.items() if vv > 1} - if conflicts_in: - raise PortError(f'Duplicate values in `map_in`: {conflicts_in}') - missing_outkeys = set(map_out.keys()) - other if missing_outkeys: raise PortError(f'`map_out` keys not present in other device: {missing_outkeys}') - connected_outkeys = set(map_out.keys()) & set(map_in.values()) - if connected_outkeys: - raise PortError(f'`map_out` keys conflict with connected ports: {connected_outkeys}') - orig_remaining = set(self.ports.keys()) - set(map_in.keys()) other_remaining = other - set(map_out.keys()) - set(map_in.values()) mapped_vals = set(map_out.values()) diff --git a/masque/ref.py b/masque/ref.py index 268d8d4..b012365 100644 --- a/masque/ref.py +++ b/masque/ref.py @@ -236,10 +236,7 @@ class Ref( bounds = numpy.vstack((numpy.min(corners, axis=0), numpy.max(corners, axis=0))) * self.scale + [self.offset] return bounds - - single_ref = self.deepcopy() - single_ref.repetition = None - return single_ref.as_pattern(pattern=pattern).get_bounds(library) + return self.as_pattern(pattern=pattern).get_bounds(library) def __repr__(self) -> str: rotation = f' r{numpy.rad2deg(self.rotation):g}' if self.rotation != 0 else '' diff --git a/masque/repetition.py b/masque/repetition.py index 1e12fcd..99e1082 100644 --- a/masque/repetition.py +++ b/masque/repetition.py @@ -184,8 +184,6 @@ class Grid(Repetition): def a_count(self, val: int) -> None: if val != int(val): raise PatternError('a_count must be convertable to an int!') - if int(val) < 1: - raise PatternError(f'Repetition has too-small a_count: {val}') self._a_count = int(val) # b_count property @@ -197,8 +195,6 @@ class Grid(Repetition): def b_count(self, val: int) -> None: if val != int(val): raise PatternError('b_count must be convertable to an int!') - if int(val) < 1: - raise PatternError(f'Repetition has too-small b_count: {val}') self._b_count = int(val) @property @@ -329,22 +325,7 @@ class Arbitrary(Repetition): @displacements.setter def displacements(self, val: ArrayLike) -> None: - try: - vala = numpy.array(val, dtype=float) - except (TypeError, ValueError) as exc: - raise PatternError('displacements must be convertible to an Nx2 ndarray') from exc - - if vala.size == 0: - self._displacements = numpy.empty((0, 2), dtype=float) - return - - if vala.ndim == 1: - if vala.size != 2: - raise PatternError('displacements must be convertible to an Nx2 ndarray') - vala = vala.reshape(1, 2) - elif vala.ndim != 2 or vala.shape[1] != 2: - raise PatternError('displacements must be convertible to an Nx2 ndarray') - + vala = numpy.array(val, dtype=float) order = numpy.lexsort(vala.T[::-1]) # sortrows self._displacements = vala[order] diff --git a/masque/shapes/poly_collection.py b/masque/shapes/poly_collection.py index 6c23da7..cd233ba 100644 --- a/masque/shapes/poly_collection.py +++ b/masque/shapes/poly_collection.py @@ -219,7 +219,7 @@ class PolyCollection(Shape): (offset, scale / norm_value, rotation, False), lambda: PolyCollection( vertex_lists=rotated_vertices * norm_value, - vertex_offsets=self._vertex_offsets.copy(), + vertex_offsets=self._vertex_offsets, ), ) diff --git a/masque/test/test_gdsii.py b/masque/test/test_gdsii.py index 7a2f5b1..7ce8c88 100644 --- a/masque/test/test_gdsii.py +++ b/masque/test/test_gdsii.py @@ -1,10 +1,8 @@ from pathlib import Path from typing import cast import numpy -import pytest from numpy.testing import assert_equal, assert_allclose -from ..error import LibraryError from ..pattern import Pattern from ..library import Library from ..file import gdsii @@ -71,10 +69,3 @@ def test_gdsii_annotations(tmp_path: Path) -> None: read_ann = read_lib["cell"].shapes[(1, 0)][0].annotations assert read_ann is not None assert read_ann["1"] == ["hello"] - - -def test_gdsii_check_valid_names_validates_generator_lengths() -> None: - names = (name for name in ("a" * 40,)) - - with pytest.raises(LibraryError, match="invalid names"): - gdsii.check_valid_names(names) diff --git a/masque/test/test_library.py b/masque/test/test_library.py index 3b731ad..6ac8536 100644 --- a/masque/test/test_library.py +++ b/masque/test/test_library.py @@ -221,77 +221,6 @@ def test_library_rename() -> None: assert "old" not in lib["parent"].refs -def test_library_dfs_can_replace_existing_patterns() -> None: - lib = Library() - child = Pattern() - lib["child"] = child - top = Pattern() - top.ref("child") - lib["top"] = top - - replacement_top = Pattern(ports={"T": Port((1, 2), 0)}) - replacement_child = Pattern(ports={"C": Port((3, 4), 0)}) - - def visit_after(pattern: Pattern, hierarchy: tuple[str | None, ...], **kwargs) -> Pattern: # noqa: ARG001 - if hierarchy[-1] == "child": - return replacement_child - if hierarchy[-1] == "top": - return replacement_top - return pattern - - lib.dfs(lib["top"], visit_after=visit_after, hierarchy=("top",), transform=True) - - assert lib["top"] is replacement_top - assert lib["child"] is replacement_child - - -def test_lazy_library_dfs_can_replace_existing_patterns() -> None: - lib = LazyLibrary() - lib["child"] = lambda: Pattern() - lib["top"] = lambda: Pattern(refs={"child": []}) - - top = lib["top"] - top.ref("child") - - replacement_top = Pattern(ports={"T": Port((1, 2), 0)}) - replacement_child = Pattern(ports={"C": Port((3, 4), 0)}) - - def visit_after(pattern: Pattern, hierarchy: tuple[str | None, ...], **kwargs) -> Pattern: # noqa: ARG001 - if hierarchy[-1] == "child": - return replacement_child - if hierarchy[-1] == "top": - return replacement_top - return pattern - - lib.dfs(top, visit_after=visit_after, hierarchy=("top",), transform=True) - - assert lib["top"] is replacement_top - assert lib["child"] is replacement_child - - -def test_library_add_no_duplicates_respects_mutate_other_false() -> None: - src_pat = Pattern(ports={"A": Port((0, 0), 0)}) - lib = Library({"a": Pattern()}) - - lib.add({"b": src_pat}, mutate_other=False) - - assert lib["b"] is not src_pat - lib["b"].ports["A"].offset[0] = 123 - assert tuple(src_pat.ports["A"].offset) == (0.0, 0.0) - - -def test_library_add_returns_only_renamed_entries() -> None: - lib = Library({"a": Pattern(), "_shape": Pattern()}) - - assert lib.add({"b": Pattern(), "c": Pattern()}, mutate_other=False) == {} - - rename_map = lib.add({"_shape": Pattern(), "keep": Pattern()}, mutate_other=False) - - assert set(rename_map) == {"_shape"} - assert rename_map["_shape"] != "_shape" - assert "keep" not in rename_map - - def test_library_subtree() -> None: lib = Library() lib["a"] = Pattern() @@ -305,26 +234,6 @@ def test_library_subtree() -> None: assert "c" not in sub -def test_library_child_order_cycle_raises_library_error() -> None: - lib = Library() - lib["a"] = Pattern() - lib["a"].ref("b") - lib["b"] = Pattern() - lib["b"].ref("a") - - with pytest.raises(LibraryError, match="Cycle found while building child order"): - lib.child_order() - - -def test_library_find_refs_global_cycle_raises_library_error() -> None: - lib = Library() - lib["a"] = Pattern() - lib["a"].ref("a") - - with pytest.raises(LibraryError, match="Cycle found while building child order"): - lib.find_refs_global("a") - - def test_library_get_name() -> None: lib = Library() lib["cell"] = Pattern() diff --git a/masque/test/test_ports.py b/masque/test/test_ports.py index 4e7d097..0f60809 100644 --- a/masque/test/test_ports.py +++ b/masque/test/test_ports.py @@ -4,7 +4,6 @@ from numpy import pi from ..ports import Port, PortList from ..error import PortError -from ..pattern import Pattern def test_port_init() -> None: @@ -228,32 +227,3 @@ def test_port_list_plugged_mismatch() -> None: pl = MyPorts() with pytest.raises(PortError): pl.plugged({"A": "B"}) - - -def test_port_list_check_ports_duplicate_map_in_values_raise() -> None: - class MyPorts(PortList): - def __init__(self) -> None: - self._ports = {"A": Port((0, 0), 0), "B": Port((0, 0), 0)} - - @property - def ports(self) -> dict[str, Port]: - return self._ports - - @ports.setter - def ports(self, val: dict[str, Port]) -> None: - self._ports = val - - pl = MyPorts() - with pytest.raises(PortError, match="Duplicate values in `map_in`"): - pl.check_ports({"X", "Y"}, map_in={"A": "X", "B": "X"}) - assert set(pl.ports) == {"A", "B"} - - -def test_pattern_plug_rejects_map_out_on_connected_ports_atomically() -> None: - host = Pattern(ports={"A": Port((0, 0), 0)}) - other = Pattern(ports={"X": Port((0, 0), pi), "Y": Port((5, 0), 0)}) - - with pytest.raises(PortError, match="`map_out` keys conflict with connected ports"): - host.plug(other, {"A": "X"}, map_out={"X": "renamed", "Y": "out"}, append=True) - - assert set(host.ports) == {"A"} diff --git a/masque/test/test_ports2data.py b/masque/test/test_ports2data.py index 3f642ab..72f6870 100644 --- a/masque/test/test_ports2data.py +++ b/masque/test/test_ports2data.py @@ -1,13 +1,10 @@ import numpy -import pytest from numpy.testing import assert_allclose from ..utils.ports2data import ports_to_data, data_to_ports from ..pattern import Pattern from ..ports import Port from ..library import Library -from ..error import PortError -from ..repetition import Grid def test_ports2data_roundtrip() -> None: @@ -77,56 +74,3 @@ def test_data_to_ports_hierarchical_scaled_ref() -> None: assert_allclose(parent.ports["A"].offset, [100, 110], atol=1e-10) assert parent.ports["A"].rotation is not None assert_allclose(parent.ports["A"].rotation, numpy.pi / 2, atol=1e-10) - - -def test_data_to_ports_hierarchical_repeated_ref_warns_and_keeps_best_effort( - caplog: pytest.LogCaptureFixture, - ) -> None: - lib = Library() - - child = Pattern() - layer = (10, 0) - child.label(layer=layer, string="A:type1 0", offset=(5, 0)) - lib["child"] = child - - parent = Pattern() - parent.ref("child", repetition=Grid(a_vector=(100, 0), a_count=3)) - - caplog.set_level("WARNING") - data_to_ports([layer], lib, parent, max_depth=1) - - assert "A" in parent.ports - assert_allclose(parent.ports["A"].offset, [5, 0], atol=1e-10) - assert any("importing only the base instance ports" in record.message for record in caplog.records) - - -def test_data_to_ports_hierarchical_collision_is_atomic() -> None: - lib = Library() - - child = Pattern() - layer = (10, 0) - child.label(layer=layer, string="A:type1 0", offset=(5, 0)) - lib["child"] = child - - parent = Pattern() - parent.ref("child", offset=(0, 0)) - parent.ref("child", offset=(10, 0)) - - with pytest.raises(PortError, match="Device ports conflict with existing ports"): - data_to_ports([layer], lib, parent, max_depth=1) - - assert not parent.ports - - -def test_data_to_ports_flat_bad_angle_warns_and_skips( - caplog: pytest.LogCaptureFixture, - ) -> None: - layer = (10, 0) - pat = Pattern() - pat.label(layer=layer, string="A:type1 nope", offset=(5, 0)) - - caplog.set_level("WARNING") - data_to_ports([layer], {}, pat) - - assert not pat.ports - assert any('bad angle' in record.message for record in caplog.records) diff --git a/masque/test/test_ref.py b/masque/test/test_ref.py index de330fa..d3e9778 100644 --- a/masque/test/test_ref.py +++ b/masque/test/test_ref.py @@ -64,22 +64,6 @@ def test_ref_get_bounds() -> None: assert_equal(bounds, [[10, 10], [20, 20]]) -def test_ref_get_bounds_single_ignores_repetition_for_non_manhattan_rotation() -> None: - sub_pat = Pattern() - sub_pat.rect((1, 0), xmin=0, xmax=1, ymin=0, ymax=2) - - rep = Grid(a_vector=(5, 0), b_vector=(0, 7), a_count=3, b_count=2) - ref = Ref(offset=(10, 20), rotation=pi / 4, repetition=rep) - - bounds = ref.get_bounds_single(sub_pat) - repeated_bounds = ref.get_bounds(sub_pat) - - assert bounds is not None - assert repeated_bounds is not None - assert repeated_bounds[1, 0] > bounds[1, 0] - assert repeated_bounds[1, 1] > bounds[1, 1] - - def test_ref_copy() -> None: ref1 = Ref(offset=(1, 2), rotation=0.5, annotations={"a": [1]}) ref2 = ref1.copy() diff --git a/masque/test/test_repetition.py b/masque/test/test_repetition.py index 0d0be41..f423ab2 100644 --- a/masque/test/test_repetition.py +++ b/masque/test/test_repetition.py @@ -1,9 +1,7 @@ -import pytest from numpy.testing import assert_equal, assert_allclose from numpy import pi from ..repetition import Grid, Arbitrary -from ..error import PatternError def test_grid_displacements() -> None: @@ -53,30 +51,6 @@ def test_arbitrary_transform() -> None: assert_allclose(arb.displacements, [[0, -10]], atol=1e-10) -def test_arbitrary_empty_repetition_is_allowed() -> None: - arb = Arbitrary([]) - assert arb.displacements.shape == (0, 2) - assert arb.get_bounds() is None - - -def test_arbitrary_rejects_non_nx2_displacements() -> None: - for displacements in ([[1], [2]], [[1, 2, 3]], [1, 2, 3]): - with pytest.raises(PatternError, match='displacements must be convertible to an Nx2 ndarray'): - Arbitrary(displacements) - - -def test_grid_count_setters_reject_nonpositive_values() -> None: - for attr, value, message in ( - ('a_count', 0, 'a_count'), - ('a_count', -1, 'a_count'), - ('b_count', 0, 'b_count'), - ('b_count', -1, 'b_count'), - ): - grid = Grid(a_vector=(10, 0), b_vector=(0, 5), a_count=2, b_count=2) - with pytest.raises(PatternError, match=message): - setattr(grid, attr, value) - - def test_repetition_less_equal_includes_equality() -> None: grid_a = Grid(a_vector=(10, 0), a_count=2) grid_b = Grid(a_vector=(10, 0), a_count=2) diff --git a/masque/test/test_shape_advanced.py b/masque/test/test_shape_advanced.py index 350e8f0..2dec264 100644 --- a/masque/test/test_shape_advanced.py +++ b/masque/test/test_shape_advanced.py @@ -212,27 +212,3 @@ def test_poly_collection_valid() -> None: assert len(sorted_shapes) == 4 # Just verify it doesn't crash and is stable assert sorted(sorted_shapes) == sorted_shapes - - -def test_poly_collection_normalized_form_reconstruction_is_independent() -> None: - pc = PolyCollection([[0, 0], [1, 0], [0, 1]], [0]) - _intrinsic, _extrinsic, rebuild = pc.normalized_form(1) - - clone = rebuild() - clone.vertex_offsets[:] = [5] - - assert_equal(pc.vertex_offsets, [0]) - assert_equal(clone.vertex_offsets, [5]) - - -def test_poly_collection_normalized_form_rebuilds_independent_clones() -> None: - pc = PolyCollection([[0, 0], [1, 0], [0, 1]], [0]) - _intrinsic, _extrinsic, rebuild = pc.normalized_form(1) - - first = rebuild() - second = rebuild() - first.vertex_offsets[:] = [7] - - assert_equal(first.vertex_offsets, [7]) - assert_equal(second.vertex_offsets, [0]) - assert_equal(pc.vertex_offsets, [0]) diff --git a/masque/test/test_svg.py b/masque/test/test_svg.py index b637853..a3261b6 100644 --- a/masque/test/test_svg.py +++ b/masque/test/test_svg.py @@ -68,31 +68,3 @@ def test_svg_ref_mirroring_changes_affine_transform(tmp_path: Path) -> None: assert_allclose(plain_transform, (0, 2, -2, 0, 3, 4), atol=1e-10) assert_allclose(mirrored_transform, (0, 2, 2, 0, 3, 4), atol=1e-10) - - -def test_svg_uses_unique_ids_for_colliding_mangled_names(tmp_path: Path) -> None: - lib = Library() - first = Pattern() - first.polygon("1", vertices=[[0, 0], [1, 0], [0, 1]]) - lib["a b"] = first - - second = Pattern() - second.polygon("1", vertices=[[0, 0], [2, 0], [0, 2]]) - lib["a-b"] = second - - top = Pattern() - top.ref("a b") - top.ref("a-b", offset=(5, 0)) - lib["top"] = top - - svg_path = tmp_path / "colliding_ids.svg" - svg.writefile(lib, "top", str(svg_path)) - - root = ET.fromstring(svg_path.read_text()) - ids = [group.attrib["id"] for group in root.iter(f"{SVG_NS}g")] - hrefs = [use.attrib[XLINK_HREF] for use in root.iter(f"{SVG_NS}use")] - - assert ids.count("a_b") == 1 - assert len(set(ids)) == len(ids) - assert "#a_b" in hrefs - assert "#a_b_2" in hrefs diff --git a/masque/test/test_utils.py b/masque/test/test_utils.py index 0511a24..45e347e 100644 --- a/masque/test/test_utils.py +++ b/masque/test/test_utils.py @@ -1,14 +1,8 @@ -from pathlib import Path - import numpy from numpy.testing import assert_equal, assert_allclose from numpy import pi -import pytest from ..utils import remove_duplicate_vertices, remove_colinear_vertices, poly_contains_points, rotation_matrix_2d, apply_transforms, DeferredDict -from ..file.utils import tmpfile -from ..utils.curves import bezier -from ..error import PatternError def test_remove_duplicate_vertices() -> None: @@ -94,19 +88,6 @@ def test_apply_transforms_advanced() -> None: assert_allclose(combined[0], [0, 10, pi / 2, 1, 1], atol=1e-10) -def test_bezier_validates_weight_length() -> None: - with pytest.raises(PatternError, match='one entry per control point'): - bezier([[0, 0], [1, 1]], [0, 0.5, 1], weights=[1]) - - with pytest.raises(PatternError, match='one entry per control point'): - bezier([[0, 0], [1, 1]], [0, 0.5, 1], weights=[1, 2, 3]) - - -def test_bezier_accepts_exact_weight_count() -> None: - samples = bezier([[0, 0], [1, 1]], [0, 0.5, 1], weights=[1, 2]) - assert_allclose(samples, [[0, 0], [2 / 3, 2 / 3], [1, 1]], atol=1e-10) - - def test_deferred_dict_accessors_resolve_values_once() -> None: calls = 0 @@ -123,48 +104,3 @@ def test_deferred_dict_accessors_resolve_values_once() -> None: assert list(deferred.values()) == [7] assert list(deferred.items()) == [("x", 7)] assert calls == 1 - - -def test_deferred_dict_mutating_accessors_preserve_value_semantics() -> None: - calls = 0 - - def make_value() -> int: - nonlocal calls - calls += 1 - return 7 - - deferred = DeferredDict[str, int]() - - assert deferred.setdefault("x", 5) == 5 - assert deferred["x"] == 5 - - assert deferred.setdefault("y", make_value) == 7 - assert deferred["y"] == 7 - assert calls == 1 - - copy_deferred = deferred.copy() - assert isinstance(copy_deferred, DeferredDict) - assert copy_deferred["x"] == 5 - assert copy_deferred["y"] == 7 - assert calls == 1 - - assert deferred.pop("x") == 5 - assert deferred.pop("missing", 9) == 9 - assert deferred.popitem() == ("y", 7) - - -def test_tmpfile_cleans_up_on_exception(tmp_path: Path) -> None: - target = tmp_path / "out.txt" - temp_path = None - - try: - with tmpfile(target) as stream: - temp_path = Path(stream.name) - stream.write(b"hello") - raise RuntimeError("boom") - except RuntimeError: - pass - - assert temp_path is not None - assert not target.exists() - assert not temp_path.exists() diff --git a/masque/utils/curves.py b/masque/utils/curves.py index 3a7671b..2348678 100644 --- a/masque/utils/curves.py +++ b/masque/utils/curves.py @@ -2,8 +2,6 @@ import numpy from numpy.typing import ArrayLike, NDArray from numpy import pi -from ..error import PatternError - try: from numpy import trapezoid except ImportError: @@ -33,11 +31,6 @@ def bezier( tt = numpy.asarray(tt) nn = nodes.shape[0] weights = numpy.ones(nn) if weights is None else numpy.asarray(weights) - if weights.ndim != 1 or weights.shape[0] != nn: - raise PatternError( - f'weights must be a 1D array with one entry per control point; ' - f'got shape {weights.shape} for {nn} control points' - ) with numpy.errstate(divide='ignore'): umul = (tt / (1 - tt)).clip(max=1) diff --git a/masque/utils/deferreddict.py b/masque/utils/deferreddict.py index 70893c0..def9b10 100644 --- a/masque/utils/deferreddict.py +++ b/masque/utils/deferreddict.py @@ -5,7 +5,6 @@ from functools import lru_cache Key = TypeVar('Key') Value = TypeVar('Value') -_MISSING = object() class DeferredDict(dict, Generic[Key, Value]): @@ -47,15 +46,6 @@ class DeferredDict(dict, Generic[Key, Value]): return default return self[key] - def setdefault(self, key: Key, default: Value | Callable[[], Value] | None = None) -> Value | None: - if key in self: - return self[key] - if callable(default): - self[key] = default - else: - self.set_const(key, default) - return self[key] - def items(self) -> Iterator[tuple[Key, Value]]: for key in self.keys(): yield key, self[key] @@ -75,25 +65,6 @@ class DeferredDict(dict, Generic[Key, Value]): else: self.set_const(k, v) - def pop(self, key: Key, default: Value | object = _MISSING) -> Value: - if key in self: - value = self[key] - dict.pop(self, key) - return value - if default is _MISSING: - raise KeyError(key) - return default # type: ignore[return-value] - - def popitem(self) -> tuple[Key, Value]: - key, value_func = dict.popitem(self) - return key, value_func() - - def copy(self) -> 'DeferredDict[Key, Value]': - new = DeferredDict[Key, Value]() - for key in self.keys(): - dict.__setitem__(new, key, dict.__getitem__(self, key)) - return new - def __repr__(self) -> str: return '' diff --git a/masque/utils/ports2data.py b/masque/utils/ports2data.py index 44a0ec3..c7f42e1 100644 --- a/masque/utils/ports2data.py +++ b/masque/utils/ports2data.py @@ -122,7 +122,6 @@ def data_to_ports( if not found_ports: return pattern - imported_ports: dict[str, Port] = {} for target, refs in pattern.refs.items(): if target is None: continue @@ -134,14 +133,9 @@ def data_to_ports( if not aa.ports: break - if ref.repetition is not None: - logger.warning(f'Pattern {name if name else pattern} has repeated ref to {target!r}; ' - 'data_to_ports() is importing only the base instance ports') aa.apply_ref_transform(ref) - Pattern(ports={**pattern.ports, **imported_ports}).check_ports(other_names=aa.ports.keys()) - imported_ports.update(aa.ports) - - pattern.ports.update(imported_ports) + pattern.check_ports(other_names=aa.ports.keys()) + pattern.ports.update(aa.ports) return pattern @@ -184,14 +178,7 @@ def data_to_ports_flat( name, property_string = label.string.split(':', 1) properties = property_string.split() ptype = properties[0] if len(properties) > 0 else 'unk' - if len(properties) > 1: - try: - angle_deg = float(properties[1]) - except ValueError: - logger.warning(f'Invalid port label "{label.string}" in pattern "{pstr}" (bad angle)') - continue - else: - angle_deg = numpy.inf + angle_deg = float(properties[1]) if len(properties) > 1 else numpy.inf xy = label.offset angle = numpy.deg2rad(angle_deg) if numpy.isfinite(angle_deg) else None @@ -203,3 +190,4 @@ def data_to_ports_flat( pattern.ports.update(local_ports) return pattern +