From 4b416745daa0ab549572cbc851ce4b286fcf5393 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Wed, 1 Apr 2026 19:21:47 -0700 Subject: [PATCH] [repetition.Grid] check for invalid displacements or counts --- masque/repetition.py | 21 ++++++++++++++++++++- masque/test/test_repetition.py | 26 ++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/masque/repetition.py b/masque/repetition.py index 99e1082..1e12fcd 100644 --- a/masque/repetition.py +++ b/masque/repetition.py @@ -184,6 +184,8 @@ class Grid(Repetition): def a_count(self, val: int) -> None: if val != int(val): raise PatternError('a_count must be convertable to an int!') + if int(val) < 1: + raise PatternError(f'Repetition has too-small a_count: {val}') self._a_count = int(val) # b_count property @@ -195,6 +197,8 @@ class Grid(Repetition): def b_count(self, val: int) -> None: if val != int(val): raise PatternError('b_count must be convertable to an int!') + if int(val) < 1: + raise PatternError(f'Repetition has too-small b_count: {val}') self._b_count = int(val) @property @@ -325,7 +329,22 @@ class Arbitrary(Repetition): @displacements.setter def displacements(self, val: ArrayLike) -> None: - vala = numpy.array(val, dtype=float) + try: + vala = numpy.array(val, dtype=float) + except (TypeError, ValueError) as exc: + raise PatternError('displacements must be convertible to an Nx2 ndarray') from exc + + if vala.size == 0: + self._displacements = numpy.empty((0, 2), dtype=float) + return + + if vala.ndim == 1: + if vala.size != 2: + raise PatternError('displacements must be convertible to an Nx2 ndarray') + vala = vala.reshape(1, 2) + elif vala.ndim != 2 or vala.shape[1] != 2: + raise PatternError('displacements must be convertible to an Nx2 ndarray') + order = numpy.lexsort(vala.T[::-1]) # sortrows self._displacements = vala[order] diff --git a/masque/test/test_repetition.py b/masque/test/test_repetition.py index f423ab2..0d0be41 100644 --- a/masque/test/test_repetition.py +++ b/masque/test/test_repetition.py @@ -1,7 +1,9 @@ +import pytest from numpy.testing import assert_equal, assert_allclose from numpy import pi from ..repetition import Grid, Arbitrary +from ..error import PatternError def test_grid_displacements() -> None: @@ -51,6 +53,30 @@ def test_arbitrary_transform() -> None: assert_allclose(arb.displacements, [[0, -10]], atol=1e-10) +def test_arbitrary_empty_repetition_is_allowed() -> None: + arb = Arbitrary([]) + assert arb.displacements.shape == (0, 2) + assert arb.get_bounds() is None + + +def test_arbitrary_rejects_non_nx2_displacements() -> None: + for displacements in ([[1], [2]], [[1, 2, 3]], [1, 2, 3]): + with pytest.raises(PatternError, match='displacements must be convertible to an Nx2 ndarray'): + Arbitrary(displacements) + + +def test_grid_count_setters_reject_nonpositive_values() -> None: + for attr, value, message in ( + ('a_count', 0, 'a_count'), + ('a_count', -1, 'a_count'), + ('b_count', 0, 'b_count'), + ('b_count', -1, 'b_count'), + ): + grid = Grid(a_vector=(10, 0), b_vector=(0, 5), a_count=2, b_count=2) + with pytest.raises(PatternError, match=message): + setattr(grid, attr, value) + + def test_repetition_less_equal_includes_equality() -> None: grid_a = Grid(a_vector=(10, 0), a_count=2) grid_b = Grid(a_vector=(10, 0), a_count=2)