diff --git a/masque/repetition.py b/masque/repetition.py index a8de94c..99e1082 100644 --- a/masque/repetition.py +++ b/masque/repetition.py @@ -34,7 +34,7 @@ class Repetition(Copyable, Rotatable, Mirrorable, Scalable, Bounded, metaclass=A pass @abstractmethod - def __le__(self, other: 'Repetition') -> bool: + def __lt__(self, other: 'Repetition') -> bool: pass @abstractmethod @@ -288,7 +288,7 @@ class Grid(Repetition): return False return True - def __le__(self, other: Repetition) -> bool: + def __lt__(self, other: Repetition) -> bool: if type(self) is not type(other): return repr(type(self)) < repr(type(other)) other = cast('Grid', other) @@ -347,7 +347,7 @@ class Arbitrary(Repetition): return False return numpy.array_equal(self.displacements, other.displacements) - def __le__(self, other: Repetition) -> bool: + def __lt__(self, other: Repetition) -> bool: if type(self) is not type(other): return repr(type(self)) < repr(type(other)) other = cast('Arbitrary', other) @@ -415,4 +415,3 @@ class Arbitrary(Repetition): """ self.displacements = self.displacements * c return self - diff --git a/masque/test/test_repetition.py b/masque/test/test_repetition.py index 5ef2fa9..f423ab2 100644 --- a/masque/test/test_repetition.py +++ b/masque/test/test_repetition.py @@ -49,3 +49,17 @@ def test_arbitrary_transform() -> None: # self.displacements[:, 1 - axis] *= -1 # if axis=0, 1-axis=1, so y *= -1 assert_allclose(arb.displacements, [[0, -10]], atol=1e-10) + + +def test_repetition_less_equal_includes_equality() -> None: + grid_a = Grid(a_vector=(10, 0), a_count=2) + grid_b = Grid(a_vector=(10, 0), a_count=2) + assert grid_a == grid_b + assert grid_a <= grid_b + assert grid_a >= grid_b + + arb_a = Arbitrary([[0, 0], [1, 0]]) + arb_b = Arbitrary([[0, 0], [1, 0]]) + assert arb_a == arb_b + assert arb_a <= arb_b + assert arb_a >= arb_b