[repetition.Grid] check for invalid displacements or counts

This commit is contained in:
Jan Petykiewicz 2026-04-01 19:21:47 -07:00
commit 4b416745da
2 changed files with 46 additions and 1 deletions

View file

@ -184,6 +184,8 @@ class Grid(Repetition):
def a_count(self, val: int) -> None:
if val != int(val):
raise PatternError('a_count must be convertable to an int!')
if int(val) < 1:
raise PatternError(f'Repetition has too-small a_count: {val}')
self._a_count = int(val)
# b_count property
@ -195,6 +197,8 @@ class Grid(Repetition):
def b_count(self, val: int) -> None:
if val != int(val):
raise PatternError('b_count must be convertable to an int!')
if int(val) < 1:
raise PatternError(f'Repetition has too-small b_count: {val}')
self._b_count = int(val)
@property
@ -325,7 +329,22 @@ class Arbitrary(Repetition):
@displacements.setter
def displacements(self, val: ArrayLike) -> None:
vala = numpy.array(val, dtype=float)
try:
vala = numpy.array(val, dtype=float)
except (TypeError, ValueError) as exc:
raise PatternError('displacements must be convertible to an Nx2 ndarray') from exc
if vala.size == 0:
self._displacements = numpy.empty((0, 2), dtype=float)
return
if vala.ndim == 1:
if vala.size != 2:
raise PatternError('displacements must be convertible to an Nx2 ndarray')
vala = vala.reshape(1, 2)
elif vala.ndim != 2 or vala.shape[1] != 2:
raise PatternError('displacements must be convertible to an Nx2 ndarray')
order = numpy.lexsort(vala.T[::-1]) # sortrows
self._displacements = vala[order]

View file

@ -1,7 +1,9 @@
import pytest
from numpy.testing import assert_equal, assert_allclose
from numpy import pi
from ..repetition import Grid, Arbitrary
from ..error import PatternError
def test_grid_displacements() -> None:
@ -51,6 +53,30 @@ def test_arbitrary_transform() -> None:
assert_allclose(arb.displacements, [[0, -10]], atol=1e-10)
def test_arbitrary_empty_repetition_is_allowed() -> None:
arb = Arbitrary([])
assert arb.displacements.shape == (0, 2)
assert arb.get_bounds() is None
def test_arbitrary_rejects_non_nx2_displacements() -> None:
for displacements in ([[1], [2]], [[1, 2, 3]], [1, 2, 3]):
with pytest.raises(PatternError, match='displacements must be convertible to an Nx2 ndarray'):
Arbitrary(displacements)
def test_grid_count_setters_reject_nonpositive_values() -> None:
for attr, value, message in (
('a_count', 0, 'a_count'),
('a_count', -1, 'a_count'),
('b_count', 0, 'b_count'),
('b_count', -1, 'b_count'),
):
grid = Grid(a_vector=(10, 0), b_vector=(0, 5), a_count=2, b_count=2)
with pytest.raises(PatternError, match=message):
setattr(grid, attr, value)
def test_repetition_less_equal_includes_equality() -> None:
grid_a = Grid(a_vector=(10, 0), a_count=2)
grid_b = Grid(a_vector=(10, 0), a_count=2)