diff --git a/masque/label.py b/masque/label.py index b662035..3dbbc08 100644 --- a/masque/label.py +++ b/masque/label.py @@ -78,6 +78,8 @@ class Label(PositionableImpl, RepeatableImpl, AnnotatableImpl, Bounded, Pivotabl return annotations_lt(self.annotations, other.annotations) def __eq__(self, other: Any) -> bool: + if type(self) is not type(other): + return False return ( self.string == other.string and numpy.array_equal(self.offset, other.offset) diff --git a/masque/ref.py b/masque/ref.py index a40776a..b012365 100644 --- a/masque/ref.py +++ b/masque/ref.py @@ -122,6 +122,8 @@ class Ref( return annotations_lt(self.annotations, other.annotations) def __eq__(self, other: Any) -> bool: + if type(self) is not type(other): + return False return ( numpy.array_equal(self.offset, other.offset) and self.mirrored == other.mirrored diff --git a/masque/test/test_label.py b/masque/test/test_label.py index ad8c08b..f4f364b 100644 --- a/masque/test/test_label.py +++ b/masque/test/test_label.py @@ -46,3 +46,9 @@ def test_label_copy() -> None: assert l1 is not l2 l2.offset[0] = 100 assert l1.offset[0] == 1 + + +def test_label_eq_unrelated_objects_is_false() -> None: + lbl = Label("test") + assert not (lbl == None) + assert not (lbl == object()) diff --git a/masque/test/test_ref.py b/masque/test/test_ref.py index c1dbf26..d3e9778 100644 --- a/masque/test/test_ref.py +++ b/masque/test/test_ref.py @@ -87,3 +87,9 @@ def test_ref_scale_by_rejects_nonpositive_scale() -> None: with pytest.raises(MasqueError, match='Scale must be positive'): ref.scale_by(-1) + + +def test_ref_eq_unrelated_objects_is_false() -> None: + ref = Ref() + assert not (ref == None) + assert not (ref == object())