Compare commits
17 commits
9767ee4e62
...
75a9114709
| Author | SHA1 | Date | |
|---|---|---|---|
| 75a9114709 | |||
| df578d7764 | |||
| 786716fc62 | |||
| a82365ec8c | |||
| 28be89f047 | |||
| afc49f945d | |||
| ce46cc18dc | |||
| 7c50f95fde | |||
| ae314cce93 | |||
| 09a95a6608 | |||
| fbe138d443 | |||
| 4b416745da | |||
| 0830dce50c | |||
| ac87179da2 | |||
| f0eea0382b | |||
| 0c9b435e94 | |||
| f461222852 |
20 changed files with 479 additions and 17 deletions
|
|
@ -637,6 +637,7 @@ def check_valid_names(
|
|||
max_length: Max allowed length
|
||||
|
||||
"""
|
||||
names = tuple(names)
|
||||
allowed_chars = set(string.ascii_letters + string.digits + '_?$')
|
||||
|
||||
bad_chars = [
|
||||
|
|
|
|||
|
|
@ -30,6 +30,21 @@ def _ref_to_svg_transform(ref) -> str:
|
|||
return f'matrix({a:g} {b:g} {c:g} {d:g} {e:g} {f:g})'
|
||||
|
||||
|
||||
def _make_svg_ids(names: Mapping[str, Pattern]) -> dict[str, str]:
|
||||
svg_ids: dict[str, str] = {}
|
||||
seen_ids: set[str] = set()
|
||||
for name in names:
|
||||
base_id = mangle_name(name)
|
||||
svg_id = base_id
|
||||
suffix = 1
|
||||
while svg_id in seen_ids:
|
||||
suffix += 1
|
||||
svg_id = f'{base_id}_{suffix}'
|
||||
seen_ids.add(svg_id)
|
||||
svg_ids[name] = svg_id
|
||||
return svg_ids
|
||||
|
||||
|
||||
def writefile(
|
||||
library: Mapping[str, Pattern],
|
||||
top: str,
|
||||
|
|
@ -81,10 +96,11 @@ def writefile(
|
|||
# Create file
|
||||
svg = svgwrite.Drawing(filename, profile='full', viewBox=viewbox_string,
|
||||
debug=(not custom_attributes))
|
||||
svg_ids = _make_svg_ids(library)
|
||||
|
||||
# Now create a group for each pattern and add in any Boundary and Use elements
|
||||
for name, pat in library.items():
|
||||
svg_group = svg.g(id=mangle_name(name), fill='blue', stroke='red')
|
||||
svg_group = svg.g(id=svg_ids[name], fill='blue', stroke='red')
|
||||
|
||||
for layer, shapes in pat.shapes.items():
|
||||
for shape in shapes:
|
||||
|
|
@ -123,11 +139,11 @@ def writefile(
|
|||
continue
|
||||
for ref in refs:
|
||||
transform = _ref_to_svg_transform(ref)
|
||||
use = svg.use(href='#' + mangle_name(target), transform=transform)
|
||||
use = svg.use(href='#' + svg_ids[target], transform=transform)
|
||||
svg_group.add(use)
|
||||
|
||||
svg.defs.add(svg_group)
|
||||
svg.add(svg.use(href='#' + mangle_name(top)))
|
||||
svg.add(svg.use(href='#' + svg_ids[top]))
|
||||
svg.save()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,12 @@ def preflight(
|
|||
Run a standard set of useful operations and checks, usually done immediately prior
|
||||
to writing to a file (or immediately after reading).
|
||||
|
||||
Note that this helper is not copy-isolating. When `sort=True`, it constructs a new
|
||||
`Library` wrapper around the same `Pattern` objects after sorting them in place, so
|
||||
later mutating preflight steps such as `prune_empty_patterns` and
|
||||
`wrap_repeated_shapes` may still mutate caller-owned patterns. Callers that need
|
||||
isolation should deep-copy the library before calling `preflight()`.
|
||||
|
||||
Args:
|
||||
sort: Whether to sort the patterns based on their names, and optionaly sort the pattern contents.
|
||||
Default True. Useful for reproducible builds.
|
||||
|
|
@ -145,7 +151,11 @@ def tmpfile(path: str | pathlib.Path) -> Iterator[IO[bytes]]:
|
|||
path = pathlib.Path(path)
|
||||
suffixes = ''.join(path.suffixes)
|
||||
with tempfile.NamedTemporaryFile(suffix=suffixes, delete=False) as tmp_stream:
|
||||
yield tmp_stream
|
||||
try:
|
||||
yield tmp_stream
|
||||
except Exception:
|
||||
pathlib.Path(tmp_stream.name).unlink(missing_ok=True)
|
||||
raise
|
||||
|
||||
try:
|
||||
shutil.move(tmp_stream.name, path)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ import copy
|
|||
from pprint import pformat
|
||||
from collections import defaultdict
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from graphlib import TopologicalSorter
|
||||
from graphlib import TopologicalSorter, CycleError
|
||||
|
||||
import numpy
|
||||
from numpy.typing import ArrayLike, NDArray
|
||||
|
|
@ -538,6 +538,7 @@ class ILibraryView(Mapping[str, 'Pattern'], metaclass=ABCMeta):
|
|||
raise LibraryError('visit_* functions returned a new `Pattern` object'
|
||||
' but no top-level name was provided in `hierarchy`')
|
||||
|
||||
del cast('ILibrary', self)[name]
|
||||
cast('ILibrary', self)[name] = pattern
|
||||
|
||||
return self
|
||||
|
|
@ -617,7 +618,13 @@ class ILibraryView(Mapping[str, 'Pattern'], metaclass=ABCMeta):
|
|||
Return:
|
||||
Topologically sorted list of pattern names.
|
||||
"""
|
||||
return cast('list[str]', list(TopologicalSorter(self.child_graph(dangling=dangling)).static_order()))
|
||||
try:
|
||||
return cast('list[str]', list(TopologicalSorter(self.child_graph(dangling=dangling)).static_order()))
|
||||
except CycleError as exc:
|
||||
cycle = exc.args[1] if len(exc.args) > 1 else None
|
||||
if cycle is None:
|
||||
raise LibraryError('Cycle found while building child order') from exc
|
||||
raise LibraryError(f'Cycle found while building child order: {cycle}') from exc
|
||||
|
||||
def find_refs_local(
|
||||
self,
|
||||
|
|
@ -916,8 +923,8 @@ class ILibrary(ILibraryView, MutableMapping[str, 'Pattern'], metaclass=ABCMeta):
|
|||
(default).
|
||||
|
||||
Returns:
|
||||
A mapping of `{old_name: new_name}` for all `old_name`s in `other`. Unchanged
|
||||
names map to themselves.
|
||||
A mapping of `{old_name: new_name}` for all names in `other` which were
|
||||
renamed while being added. Unchanged names are omitted.
|
||||
|
||||
Raises:
|
||||
`LibraryError` if a duplicate name is encountered even after applying `rename_theirs()`.
|
||||
|
|
@ -926,8 +933,13 @@ class ILibrary(ILibraryView, MutableMapping[str, 'Pattern'], metaclass=ABCMeta):
|
|||
duplicates = set(self.keys()) & set(other.keys())
|
||||
|
||||
if not duplicates:
|
||||
for key in other:
|
||||
self._merge(key, other, key)
|
||||
if mutate_other:
|
||||
temp = other
|
||||
else:
|
||||
temp = Library(copy.deepcopy(dict(other)))
|
||||
|
||||
for key in temp:
|
||||
self._merge(key, temp, key)
|
||||
return {}
|
||||
|
||||
if mutate_other:
|
||||
|
|
|
|||
|
|
@ -510,10 +510,19 @@ class PortList(metaclass=ABCMeta):
|
|||
if missing_invals:
|
||||
raise PortError(f'`map_in` values not present in other device: {missing_invals}')
|
||||
|
||||
map_in_counts = Counter(map_in.values())
|
||||
conflicts_in = {kk for kk, vv in map_in_counts.items() if vv > 1}
|
||||
if conflicts_in:
|
||||
raise PortError(f'Duplicate values in `map_in`: {conflicts_in}')
|
||||
|
||||
missing_outkeys = set(map_out.keys()) - other
|
||||
if missing_outkeys:
|
||||
raise PortError(f'`map_out` keys not present in other device: {missing_outkeys}')
|
||||
|
||||
connected_outkeys = set(map_out.keys()) & set(map_in.values())
|
||||
if connected_outkeys:
|
||||
raise PortError(f'`map_out` keys conflict with connected ports: {connected_outkeys}')
|
||||
|
||||
orig_remaining = set(self.ports.keys()) - set(map_in.keys())
|
||||
other_remaining = other - set(map_out.keys()) - set(map_in.values())
|
||||
mapped_vals = set(map_out.values())
|
||||
|
|
|
|||
|
|
@ -236,7 +236,10 @@ class Ref(
|
|||
bounds = numpy.vstack((numpy.min(corners, axis=0),
|
||||
numpy.max(corners, axis=0))) * self.scale + [self.offset]
|
||||
return bounds
|
||||
return self.as_pattern(pattern=pattern).get_bounds(library)
|
||||
|
||||
single_ref = self.deepcopy()
|
||||
single_ref.repetition = None
|
||||
return single_ref.as_pattern(pattern=pattern).get_bounds(library)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rotation = f' r{numpy.rad2deg(self.rotation):g}' if self.rotation != 0 else ''
|
||||
|
|
|
|||
|
|
@ -184,6 +184,8 @@ class Grid(Repetition):
|
|||
def a_count(self, val: int) -> None:
|
||||
if val != int(val):
|
||||
raise PatternError('a_count must be convertable to an int!')
|
||||
if int(val) < 1:
|
||||
raise PatternError(f'Repetition has too-small a_count: {val}')
|
||||
self._a_count = int(val)
|
||||
|
||||
# b_count property
|
||||
|
|
@ -195,6 +197,8 @@ class Grid(Repetition):
|
|||
def b_count(self, val: int) -> None:
|
||||
if val != int(val):
|
||||
raise PatternError('b_count must be convertable to an int!')
|
||||
if int(val) < 1:
|
||||
raise PatternError(f'Repetition has too-small b_count: {val}')
|
||||
self._b_count = int(val)
|
||||
|
||||
@property
|
||||
|
|
@ -325,7 +329,22 @@ class Arbitrary(Repetition):
|
|||
|
||||
@displacements.setter
|
||||
def displacements(self, val: ArrayLike) -> None:
|
||||
vala = numpy.array(val, dtype=float)
|
||||
try:
|
||||
vala = numpy.array(val, dtype=float)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise PatternError('displacements must be convertible to an Nx2 ndarray') from exc
|
||||
|
||||
if vala.size == 0:
|
||||
self._displacements = numpy.empty((0, 2), dtype=float)
|
||||
return
|
||||
|
||||
if vala.ndim == 1:
|
||||
if vala.size != 2:
|
||||
raise PatternError('displacements must be convertible to an Nx2 ndarray')
|
||||
vala = vala.reshape(1, 2)
|
||||
elif vala.ndim != 2 or vala.shape[1] != 2:
|
||||
raise PatternError('displacements must be convertible to an Nx2 ndarray')
|
||||
|
||||
order = numpy.lexsort(vala.T[::-1]) # sortrows
|
||||
self._displacements = vala[order]
|
||||
|
||||
|
|
|
|||
|
|
@ -219,7 +219,7 @@ class PolyCollection(Shape):
|
|||
(offset, scale / norm_value, rotation, False),
|
||||
lambda: PolyCollection(
|
||||
vertex_lists=rotated_vertices * norm_value,
|
||||
vertex_offsets=self._vertex_offsets,
|
||||
vertex_offsets=self._vertex_offsets.copy(),
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from pathlib import Path
|
||||
from typing import cast
|
||||
import numpy
|
||||
import pytest
|
||||
from numpy.testing import assert_equal, assert_allclose
|
||||
|
||||
from ..error import LibraryError
|
||||
from ..pattern import Pattern
|
||||
from ..library import Library
|
||||
from ..file import gdsii
|
||||
|
|
@ -69,3 +71,10 @@ def test_gdsii_annotations(tmp_path: Path) -> None:
|
|||
read_ann = read_lib["cell"].shapes[(1, 0)][0].annotations
|
||||
assert read_ann is not None
|
||||
assert read_ann["1"] == ["hello"]
|
||||
|
||||
|
||||
def test_gdsii_check_valid_names_validates_generator_lengths() -> None:
|
||||
names = (name for name in ("a" * 40,))
|
||||
|
||||
with pytest.raises(LibraryError, match="invalid names"):
|
||||
gdsii.check_valid_names(names)
|
||||
|
|
|
|||
|
|
@ -221,6 +221,77 @@ def test_library_rename() -> None:
|
|||
assert "old" not in lib["parent"].refs
|
||||
|
||||
|
||||
def test_library_dfs_can_replace_existing_patterns() -> None:
|
||||
lib = Library()
|
||||
child = Pattern()
|
||||
lib["child"] = child
|
||||
top = Pattern()
|
||||
top.ref("child")
|
||||
lib["top"] = top
|
||||
|
||||
replacement_top = Pattern(ports={"T": Port((1, 2), 0)})
|
||||
replacement_child = Pattern(ports={"C": Port((3, 4), 0)})
|
||||
|
||||
def visit_after(pattern: Pattern, hierarchy: tuple[str | None, ...], **kwargs) -> Pattern: # noqa: ARG001
|
||||
if hierarchy[-1] == "child":
|
||||
return replacement_child
|
||||
if hierarchy[-1] == "top":
|
||||
return replacement_top
|
||||
return pattern
|
||||
|
||||
lib.dfs(lib["top"], visit_after=visit_after, hierarchy=("top",), transform=True)
|
||||
|
||||
assert lib["top"] is replacement_top
|
||||
assert lib["child"] is replacement_child
|
||||
|
||||
|
||||
def test_lazy_library_dfs_can_replace_existing_patterns() -> None:
|
||||
lib = LazyLibrary()
|
||||
lib["child"] = lambda: Pattern()
|
||||
lib["top"] = lambda: Pattern(refs={"child": []})
|
||||
|
||||
top = lib["top"]
|
||||
top.ref("child")
|
||||
|
||||
replacement_top = Pattern(ports={"T": Port((1, 2), 0)})
|
||||
replacement_child = Pattern(ports={"C": Port((3, 4), 0)})
|
||||
|
||||
def visit_after(pattern: Pattern, hierarchy: tuple[str | None, ...], **kwargs) -> Pattern: # noqa: ARG001
|
||||
if hierarchy[-1] == "child":
|
||||
return replacement_child
|
||||
if hierarchy[-1] == "top":
|
||||
return replacement_top
|
||||
return pattern
|
||||
|
||||
lib.dfs(top, visit_after=visit_after, hierarchy=("top",), transform=True)
|
||||
|
||||
assert lib["top"] is replacement_top
|
||||
assert lib["child"] is replacement_child
|
||||
|
||||
|
||||
def test_library_add_no_duplicates_respects_mutate_other_false() -> None:
|
||||
src_pat = Pattern(ports={"A": Port((0, 0), 0)})
|
||||
lib = Library({"a": Pattern()})
|
||||
|
||||
lib.add({"b": src_pat}, mutate_other=False)
|
||||
|
||||
assert lib["b"] is not src_pat
|
||||
lib["b"].ports["A"].offset[0] = 123
|
||||
assert tuple(src_pat.ports["A"].offset) == (0.0, 0.0)
|
||||
|
||||
|
||||
def test_library_add_returns_only_renamed_entries() -> None:
|
||||
lib = Library({"a": Pattern(), "_shape": Pattern()})
|
||||
|
||||
assert lib.add({"b": Pattern(), "c": Pattern()}, mutate_other=False) == {}
|
||||
|
||||
rename_map = lib.add({"_shape": Pattern(), "keep": Pattern()}, mutate_other=False)
|
||||
|
||||
assert set(rename_map) == {"_shape"}
|
||||
assert rename_map["_shape"] != "_shape"
|
||||
assert "keep" not in rename_map
|
||||
|
||||
|
||||
def test_library_subtree() -> None:
|
||||
lib = Library()
|
||||
lib["a"] = Pattern()
|
||||
|
|
@ -234,6 +305,26 @@ def test_library_subtree() -> None:
|
|||
assert "c" not in sub
|
||||
|
||||
|
||||
def test_library_child_order_cycle_raises_library_error() -> None:
|
||||
lib = Library()
|
||||
lib["a"] = Pattern()
|
||||
lib["a"].ref("b")
|
||||
lib["b"] = Pattern()
|
||||
lib["b"].ref("a")
|
||||
|
||||
with pytest.raises(LibraryError, match="Cycle found while building child order"):
|
||||
lib.child_order()
|
||||
|
||||
|
||||
def test_library_find_refs_global_cycle_raises_library_error() -> None:
|
||||
lib = Library()
|
||||
lib["a"] = Pattern()
|
||||
lib["a"].ref("a")
|
||||
|
||||
with pytest.raises(LibraryError, match="Cycle found while building child order"):
|
||||
lib.find_refs_global("a")
|
||||
|
||||
|
||||
def test_library_get_name() -> None:
|
||||
lib = Library()
|
||||
lib["cell"] = Pattern()
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from numpy import pi
|
|||
|
||||
from ..ports import Port, PortList
|
||||
from ..error import PortError
|
||||
from ..pattern import Pattern
|
||||
|
||||
|
||||
def test_port_init() -> None:
|
||||
|
|
@ -227,3 +228,32 @@ def test_port_list_plugged_mismatch() -> None:
|
|||
pl = MyPorts()
|
||||
with pytest.raises(PortError):
|
||||
pl.plugged({"A": "B"})
|
||||
|
||||
|
||||
def test_port_list_check_ports_duplicate_map_in_values_raise() -> None:
|
||||
class MyPorts(PortList):
|
||||
def __init__(self) -> None:
|
||||
self._ports = {"A": Port((0, 0), 0), "B": Port((0, 0), 0)}
|
||||
|
||||
@property
|
||||
def ports(self) -> dict[str, Port]:
|
||||
return self._ports
|
||||
|
||||
@ports.setter
|
||||
def ports(self, val: dict[str, Port]) -> None:
|
||||
self._ports = val
|
||||
|
||||
pl = MyPorts()
|
||||
with pytest.raises(PortError, match="Duplicate values in `map_in`"):
|
||||
pl.check_ports({"X", "Y"}, map_in={"A": "X", "B": "X"})
|
||||
assert set(pl.ports) == {"A", "B"}
|
||||
|
||||
|
||||
def test_pattern_plug_rejects_map_out_on_connected_ports_atomically() -> None:
|
||||
host = Pattern(ports={"A": Port((0, 0), 0)})
|
||||
other = Pattern(ports={"X": Port((0, 0), pi), "Y": Port((5, 0), 0)})
|
||||
|
||||
with pytest.raises(PortError, match="`map_out` keys conflict with connected ports"):
|
||||
host.plug(other, {"A": "X"}, map_out={"X": "renamed", "Y": "out"}, append=True)
|
||||
|
||||
assert set(host.ports) == {"A"}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
import numpy
|
||||
import pytest
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from ..utils.ports2data import ports_to_data, data_to_ports
|
||||
from ..pattern import Pattern
|
||||
from ..ports import Port
|
||||
from ..library import Library
|
||||
from ..error import PortError
|
||||
from ..repetition import Grid
|
||||
|
||||
|
||||
def test_ports2data_roundtrip() -> None:
|
||||
|
|
@ -74,3 +77,56 @@ def test_data_to_ports_hierarchical_scaled_ref() -> None:
|
|||
assert_allclose(parent.ports["A"].offset, [100, 110], atol=1e-10)
|
||||
assert parent.ports["A"].rotation is not None
|
||||
assert_allclose(parent.ports["A"].rotation, numpy.pi / 2, atol=1e-10)
|
||||
|
||||
|
||||
def test_data_to_ports_hierarchical_repeated_ref_warns_and_keeps_best_effort(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
lib = Library()
|
||||
|
||||
child = Pattern()
|
||||
layer = (10, 0)
|
||||
child.label(layer=layer, string="A:type1 0", offset=(5, 0))
|
||||
lib["child"] = child
|
||||
|
||||
parent = Pattern()
|
||||
parent.ref("child", repetition=Grid(a_vector=(100, 0), a_count=3))
|
||||
|
||||
caplog.set_level("WARNING")
|
||||
data_to_ports([layer], lib, parent, max_depth=1)
|
||||
|
||||
assert "A" in parent.ports
|
||||
assert_allclose(parent.ports["A"].offset, [5, 0], atol=1e-10)
|
||||
assert any("importing only the base instance ports" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
def test_data_to_ports_hierarchical_collision_is_atomic() -> None:
|
||||
lib = Library()
|
||||
|
||||
child = Pattern()
|
||||
layer = (10, 0)
|
||||
child.label(layer=layer, string="A:type1 0", offset=(5, 0))
|
||||
lib["child"] = child
|
||||
|
||||
parent = Pattern()
|
||||
parent.ref("child", offset=(0, 0))
|
||||
parent.ref("child", offset=(10, 0))
|
||||
|
||||
with pytest.raises(PortError, match="Device ports conflict with existing ports"):
|
||||
data_to_ports([layer], lib, parent, max_depth=1)
|
||||
|
||||
assert not parent.ports
|
||||
|
||||
|
||||
def test_data_to_ports_flat_bad_angle_warns_and_skips(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
layer = (10, 0)
|
||||
pat = Pattern()
|
||||
pat.label(layer=layer, string="A:type1 nope", offset=(5, 0))
|
||||
|
||||
caplog.set_level("WARNING")
|
||||
data_to_ports([layer], {}, pat)
|
||||
|
||||
assert not pat.ports
|
||||
assert any('bad angle' in record.message for record in caplog.records)
|
||||
|
|
|
|||
|
|
@ -64,6 +64,22 @@ def test_ref_get_bounds() -> None:
|
|||
assert_equal(bounds, [[10, 10], [20, 20]])
|
||||
|
||||
|
||||
def test_ref_get_bounds_single_ignores_repetition_for_non_manhattan_rotation() -> None:
|
||||
sub_pat = Pattern()
|
||||
sub_pat.rect((1, 0), xmin=0, xmax=1, ymin=0, ymax=2)
|
||||
|
||||
rep = Grid(a_vector=(5, 0), b_vector=(0, 7), a_count=3, b_count=2)
|
||||
ref = Ref(offset=(10, 20), rotation=pi / 4, repetition=rep)
|
||||
|
||||
bounds = ref.get_bounds_single(sub_pat)
|
||||
repeated_bounds = ref.get_bounds(sub_pat)
|
||||
|
||||
assert bounds is not None
|
||||
assert repeated_bounds is not None
|
||||
assert repeated_bounds[1, 0] > bounds[1, 0]
|
||||
assert repeated_bounds[1, 1] > bounds[1, 1]
|
||||
|
||||
|
||||
def test_ref_copy() -> None:
|
||||
ref1 = Ref(offset=(1, 2), rotation=0.5, annotations={"a": [1]})
|
||||
ref2 = ref1.copy()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import pytest
|
||||
from numpy.testing import assert_equal, assert_allclose
|
||||
from numpy import pi
|
||||
|
||||
from ..repetition import Grid, Arbitrary
|
||||
from ..error import PatternError
|
||||
|
||||
|
||||
def test_grid_displacements() -> None:
|
||||
|
|
@ -51,6 +53,30 @@ def test_arbitrary_transform() -> None:
|
|||
assert_allclose(arb.displacements, [[0, -10]], atol=1e-10)
|
||||
|
||||
|
||||
def test_arbitrary_empty_repetition_is_allowed() -> None:
|
||||
arb = Arbitrary([])
|
||||
assert arb.displacements.shape == (0, 2)
|
||||
assert arb.get_bounds() is None
|
||||
|
||||
|
||||
def test_arbitrary_rejects_non_nx2_displacements() -> None:
|
||||
for displacements in ([[1], [2]], [[1, 2, 3]], [1, 2, 3]):
|
||||
with pytest.raises(PatternError, match='displacements must be convertible to an Nx2 ndarray'):
|
||||
Arbitrary(displacements)
|
||||
|
||||
|
||||
def test_grid_count_setters_reject_nonpositive_values() -> None:
|
||||
for attr, value, message in (
|
||||
('a_count', 0, 'a_count'),
|
||||
('a_count', -1, 'a_count'),
|
||||
('b_count', 0, 'b_count'),
|
||||
('b_count', -1, 'b_count'),
|
||||
):
|
||||
grid = Grid(a_vector=(10, 0), b_vector=(0, 5), a_count=2, b_count=2)
|
||||
with pytest.raises(PatternError, match=message):
|
||||
setattr(grid, attr, value)
|
||||
|
||||
|
||||
def test_repetition_less_equal_includes_equality() -> None:
|
||||
grid_a = Grid(a_vector=(10, 0), a_count=2)
|
||||
grid_b = Grid(a_vector=(10, 0), a_count=2)
|
||||
|
|
|
|||
|
|
@ -212,3 +212,27 @@ def test_poly_collection_valid() -> None:
|
|||
assert len(sorted_shapes) == 4
|
||||
# Just verify it doesn't crash and is stable
|
||||
assert sorted(sorted_shapes) == sorted_shapes
|
||||
|
||||
|
||||
def test_poly_collection_normalized_form_reconstruction_is_independent() -> None:
|
||||
pc = PolyCollection([[0, 0], [1, 0], [0, 1]], [0])
|
||||
_intrinsic, _extrinsic, rebuild = pc.normalized_form(1)
|
||||
|
||||
clone = rebuild()
|
||||
clone.vertex_offsets[:] = [5]
|
||||
|
||||
assert_equal(pc.vertex_offsets, [0])
|
||||
assert_equal(clone.vertex_offsets, [5])
|
||||
|
||||
|
||||
def test_poly_collection_normalized_form_rebuilds_independent_clones() -> None:
|
||||
pc = PolyCollection([[0, 0], [1, 0], [0, 1]], [0])
|
||||
_intrinsic, _extrinsic, rebuild = pc.normalized_form(1)
|
||||
|
||||
first = rebuild()
|
||||
second = rebuild()
|
||||
first.vertex_offsets[:] = [7]
|
||||
|
||||
assert_equal(first.vertex_offsets, [7])
|
||||
assert_equal(second.vertex_offsets, [0])
|
||||
assert_equal(pc.vertex_offsets, [0])
|
||||
|
|
|
|||
|
|
@ -68,3 +68,31 @@ def test_svg_ref_mirroring_changes_affine_transform(tmp_path: Path) -> None:
|
|||
|
||||
assert_allclose(plain_transform, (0, 2, -2, 0, 3, 4), atol=1e-10)
|
||||
assert_allclose(mirrored_transform, (0, 2, 2, 0, 3, 4), atol=1e-10)
|
||||
|
||||
|
||||
def test_svg_uses_unique_ids_for_colliding_mangled_names(tmp_path: Path) -> None:
|
||||
lib = Library()
|
||||
first = Pattern()
|
||||
first.polygon("1", vertices=[[0, 0], [1, 0], [0, 1]])
|
||||
lib["a b"] = first
|
||||
|
||||
second = Pattern()
|
||||
second.polygon("1", vertices=[[0, 0], [2, 0], [0, 2]])
|
||||
lib["a-b"] = second
|
||||
|
||||
top = Pattern()
|
||||
top.ref("a b")
|
||||
top.ref("a-b", offset=(5, 0))
|
||||
lib["top"] = top
|
||||
|
||||
svg_path = tmp_path / "colliding_ids.svg"
|
||||
svg.writefile(lib, "top", str(svg_path))
|
||||
|
||||
root = ET.fromstring(svg_path.read_text())
|
||||
ids = [group.attrib["id"] for group in root.iter(f"{SVG_NS}g")]
|
||||
hrefs = [use.attrib[XLINK_HREF] for use in root.iter(f"{SVG_NS}use")]
|
||||
|
||||
assert ids.count("a_b") == 1
|
||||
assert len(set(ids)) == len(ids)
|
||||
assert "#a_b" in hrefs
|
||||
assert "#a_b_2" in hrefs
|
||||
|
|
|
|||
|
|
@ -1,8 +1,14 @@
|
|||
from pathlib import Path
|
||||
|
||||
import numpy
|
||||
from numpy.testing import assert_equal, assert_allclose
|
||||
from numpy import pi
|
||||
import pytest
|
||||
|
||||
from ..utils import remove_duplicate_vertices, remove_colinear_vertices, poly_contains_points, rotation_matrix_2d, apply_transforms, DeferredDict
|
||||
from ..file.utils import tmpfile
|
||||
from ..utils.curves import bezier
|
||||
from ..error import PatternError
|
||||
|
||||
|
||||
def test_remove_duplicate_vertices() -> None:
|
||||
|
|
@ -88,6 +94,19 @@ def test_apply_transforms_advanced() -> None:
|
|||
assert_allclose(combined[0], [0, 10, pi / 2, 1, 1], atol=1e-10)
|
||||
|
||||
|
||||
def test_bezier_validates_weight_length() -> None:
|
||||
with pytest.raises(PatternError, match='one entry per control point'):
|
||||
bezier([[0, 0], [1, 1]], [0, 0.5, 1], weights=[1])
|
||||
|
||||
with pytest.raises(PatternError, match='one entry per control point'):
|
||||
bezier([[0, 0], [1, 1]], [0, 0.5, 1], weights=[1, 2, 3])
|
||||
|
||||
|
||||
def test_bezier_accepts_exact_weight_count() -> None:
|
||||
samples = bezier([[0, 0], [1, 1]], [0, 0.5, 1], weights=[1, 2])
|
||||
assert_allclose(samples, [[0, 0], [2 / 3, 2 / 3], [1, 1]], atol=1e-10)
|
||||
|
||||
|
||||
def test_deferred_dict_accessors_resolve_values_once() -> None:
|
||||
calls = 0
|
||||
|
||||
|
|
@ -104,3 +123,48 @@ def test_deferred_dict_accessors_resolve_values_once() -> None:
|
|||
assert list(deferred.values()) == [7]
|
||||
assert list(deferred.items()) == [("x", 7)]
|
||||
assert calls == 1
|
||||
|
||||
|
||||
def test_deferred_dict_mutating_accessors_preserve_value_semantics() -> None:
|
||||
calls = 0
|
||||
|
||||
def make_value() -> int:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return 7
|
||||
|
||||
deferred = DeferredDict[str, int]()
|
||||
|
||||
assert deferred.setdefault("x", 5) == 5
|
||||
assert deferred["x"] == 5
|
||||
|
||||
assert deferred.setdefault("y", make_value) == 7
|
||||
assert deferred["y"] == 7
|
||||
assert calls == 1
|
||||
|
||||
copy_deferred = deferred.copy()
|
||||
assert isinstance(copy_deferred, DeferredDict)
|
||||
assert copy_deferred["x"] == 5
|
||||
assert copy_deferred["y"] == 7
|
||||
assert calls == 1
|
||||
|
||||
assert deferred.pop("x") == 5
|
||||
assert deferred.pop("missing", 9) == 9
|
||||
assert deferred.popitem() == ("y", 7)
|
||||
|
||||
|
||||
def test_tmpfile_cleans_up_on_exception(tmp_path: Path) -> None:
|
||||
target = tmp_path / "out.txt"
|
||||
temp_path = None
|
||||
|
||||
try:
|
||||
with tmpfile(target) as stream:
|
||||
temp_path = Path(stream.name)
|
||||
stream.write(b"hello")
|
||||
raise RuntimeError("boom")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
assert temp_path is not None
|
||||
assert not target.exists()
|
||||
assert not temp_path.exists()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ import numpy
|
|||
from numpy.typing import ArrayLike, NDArray
|
||||
from numpy import pi
|
||||
|
||||
from ..error import PatternError
|
||||
|
||||
try:
|
||||
from numpy import trapezoid
|
||||
except ImportError:
|
||||
|
|
@ -31,6 +33,11 @@ def bezier(
|
|||
tt = numpy.asarray(tt)
|
||||
nn = nodes.shape[0]
|
||||
weights = numpy.ones(nn) if weights is None else numpy.asarray(weights)
|
||||
if weights.ndim != 1 or weights.shape[0] != nn:
|
||||
raise PatternError(
|
||||
f'weights must be a 1D array with one entry per control point; '
|
||||
f'got shape {weights.shape} for {nn} control points'
|
||||
)
|
||||
|
||||
with numpy.errstate(divide='ignore'):
|
||||
umul = (tt / (1 - tt)).clip(max=1)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from functools import lru_cache
|
|||
|
||||
Key = TypeVar('Key')
|
||||
Value = TypeVar('Value')
|
||||
_MISSING = object()
|
||||
|
||||
|
||||
class DeferredDict(dict, Generic[Key, Value]):
|
||||
|
|
@ -46,6 +47,15 @@ class DeferredDict(dict, Generic[Key, Value]):
|
|||
return default
|
||||
return self[key]
|
||||
|
||||
def setdefault(self, key: Key, default: Value | Callable[[], Value] | None = None) -> Value | None:
|
||||
if key in self:
|
||||
return self[key]
|
||||
if callable(default):
|
||||
self[key] = default
|
||||
else:
|
||||
self.set_const(key, default)
|
||||
return self[key]
|
||||
|
||||
def items(self) -> Iterator[tuple[Key, Value]]:
|
||||
for key in self.keys():
|
||||
yield key, self[key]
|
||||
|
|
@ -65,6 +75,25 @@ class DeferredDict(dict, Generic[Key, Value]):
|
|||
else:
|
||||
self.set_const(k, v)
|
||||
|
||||
def pop(self, key: Key, default: Value | object = _MISSING) -> Value:
|
||||
if key in self:
|
||||
value = self[key]
|
||||
dict.pop(self, key)
|
||||
return value
|
||||
if default is _MISSING:
|
||||
raise KeyError(key)
|
||||
return default # type: ignore[return-value]
|
||||
|
||||
def popitem(self) -> tuple[Key, Value]:
|
||||
key, value_func = dict.popitem(self)
|
||||
return key, value_func()
|
||||
|
||||
def copy(self) -> 'DeferredDict[Key, Value]':
|
||||
new = DeferredDict[Key, Value]()
|
||||
for key in self.keys():
|
||||
dict.__setitem__(new, key, dict.__getitem__(self, key))
|
||||
return new
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return '<DeferredDict with keys ' + repr(set(self.keys())) + '>'
|
||||
|
||||
|
|
|
|||
|
|
@ -122,6 +122,7 @@ def data_to_ports(
|
|||
if not found_ports:
|
||||
return pattern
|
||||
|
||||
imported_ports: dict[str, Port] = {}
|
||||
for target, refs in pattern.refs.items():
|
||||
if target is None:
|
||||
continue
|
||||
|
|
@ -133,9 +134,14 @@ def data_to_ports(
|
|||
if not aa.ports:
|
||||
break
|
||||
|
||||
if ref.repetition is not None:
|
||||
logger.warning(f'Pattern {name if name else pattern} has repeated ref to {target!r}; '
|
||||
'data_to_ports() is importing only the base instance ports')
|
||||
aa.apply_ref_transform(ref)
|
||||
pattern.check_ports(other_names=aa.ports.keys())
|
||||
pattern.ports.update(aa.ports)
|
||||
Pattern(ports={**pattern.ports, **imported_ports}).check_ports(other_names=aa.ports.keys())
|
||||
imported_ports.update(aa.ports)
|
||||
|
||||
pattern.ports.update(imported_ports)
|
||||
return pattern
|
||||
|
||||
|
||||
|
|
@ -178,7 +184,14 @@ def data_to_ports_flat(
|
|||
name, property_string = label.string.split(':', 1)
|
||||
properties = property_string.split()
|
||||
ptype = properties[0] if len(properties) > 0 else 'unk'
|
||||
angle_deg = float(properties[1]) if len(properties) > 1 else numpy.inf
|
||||
if len(properties) > 1:
|
||||
try:
|
||||
angle_deg = float(properties[1])
|
||||
except ValueError:
|
||||
logger.warning(f'Invalid port label "{label.string}" in pattern "{pstr}" (bad angle)')
|
||||
continue
|
||||
else:
|
||||
angle_deg = numpy.inf
|
||||
|
||||
xy = label.offset
|
||||
angle = numpy.deg2rad(angle_deg) if numpy.isfinite(angle_deg) else None
|
||||
|
|
@ -190,4 +203,3 @@ def data_to_ports_flat(
|
|||
|
||||
pattern.ports.update(local_ports)
|
||||
return pattern
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue