[repetition.Grid] check for invalid displacements or counts
This commit is contained in:
parent
0830dce50c
commit
4b416745da
2 changed files with 46 additions and 1 deletions
|
|
@ -184,6 +184,8 @@ class Grid(Repetition):
|
|||
def a_count(self, val: int) -> None:
|
||||
if val != int(val):
|
||||
raise PatternError('a_count must be convertable to an int!')
|
||||
if int(val) < 1:
|
||||
raise PatternError(f'Repetition has too-small a_count: {val}')
|
||||
self._a_count = int(val)
|
||||
|
||||
# b_count property
|
||||
|
|
@ -195,6 +197,8 @@ class Grid(Repetition):
|
|||
def b_count(self, val: int) -> None:
|
||||
if val != int(val):
|
||||
raise PatternError('b_count must be convertable to an int!')
|
||||
if int(val) < 1:
|
||||
raise PatternError(f'Repetition has too-small b_count: {val}')
|
||||
self._b_count = int(val)
|
||||
|
||||
@property
|
||||
|
|
@ -325,7 +329,22 @@ class Arbitrary(Repetition):
|
|||
|
||||
@displacements.setter
|
||||
def displacements(self, val: ArrayLike) -> None:
|
||||
vala = numpy.array(val, dtype=float)
|
||||
try:
|
||||
vala = numpy.array(val, dtype=float)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise PatternError('displacements must be convertible to an Nx2 ndarray') from exc
|
||||
|
||||
if vala.size == 0:
|
||||
self._displacements = numpy.empty((0, 2), dtype=float)
|
||||
return
|
||||
|
||||
if vala.ndim == 1:
|
||||
if vala.size != 2:
|
||||
raise PatternError('displacements must be convertible to an Nx2 ndarray')
|
||||
vala = vala.reshape(1, 2)
|
||||
elif vala.ndim != 2 or vala.shape[1] != 2:
|
||||
raise PatternError('displacements must be convertible to an Nx2 ndarray')
|
||||
|
||||
order = numpy.lexsort(vala.T[::-1]) # sortrows
|
||||
self._displacements = vala[order]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import pytest
|
||||
from numpy.testing import assert_equal, assert_allclose
|
||||
from numpy import pi
|
||||
|
||||
from ..repetition import Grid, Arbitrary
|
||||
from ..error import PatternError
|
||||
|
||||
|
||||
def test_grid_displacements() -> None:
|
||||
|
|
@ -51,6 +53,30 @@ def test_arbitrary_transform() -> None:
|
|||
assert_allclose(arb.displacements, [[0, -10]], atol=1e-10)
|
||||
|
||||
|
||||
def test_arbitrary_empty_repetition_is_allowed() -> None:
|
||||
arb = Arbitrary([])
|
||||
assert arb.displacements.shape == (0, 2)
|
||||
assert arb.get_bounds() is None
|
||||
|
||||
|
||||
def test_arbitrary_rejects_non_nx2_displacements() -> None:
|
||||
for displacements in ([[1], [2]], [[1, 2, 3]], [1, 2, 3]):
|
||||
with pytest.raises(PatternError, match='displacements must be convertible to an Nx2 ndarray'):
|
||||
Arbitrary(displacements)
|
||||
|
||||
|
||||
def test_grid_count_setters_reject_nonpositive_values() -> None:
|
||||
for attr, value, message in (
|
||||
('a_count', 0, 'a_count'),
|
||||
('a_count', -1, 'a_count'),
|
||||
('b_count', 0, 'b_count'),
|
||||
('b_count', -1, 'b_count'),
|
||||
):
|
||||
grid = Grid(a_vector=(10, 0), b_vector=(0, 5), a_count=2, b_count=2)
|
||||
with pytest.raises(PatternError, match=message):
|
||||
setattr(grid, attr, value)
|
||||
|
||||
|
||||
def test_repetition_less_equal_includes_equality() -> None:
|
||||
grid_a = Grid(a_vector=(10, 0), a_count=2)
|
||||
grid_b = Grid(a_vector=(10, 0), a_count=2)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue