From 75a91147098db84588932ae843e33f9b71c986ec Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Wed, 1 Apr 2026 21:16:03 -0700 Subject: [PATCH] [bezier] validate weights --- masque/test/test_utils.py | 16 ++++++++++++++++ masque/utils/curves.py | 7 +++++++ 2 files changed, 23 insertions(+) diff --git a/masque/test/test_utils.py b/masque/test/test_utils.py index 50b700b..0511a24 100644 --- a/masque/test/test_utils.py +++ b/masque/test/test_utils.py @@ -3,9 +3,12 @@ from pathlib import Path import numpy from numpy.testing import assert_equal, assert_allclose from numpy import pi +import pytest from ..utils import remove_duplicate_vertices, remove_colinear_vertices, poly_contains_points, rotation_matrix_2d, apply_transforms, DeferredDict from ..file.utils import tmpfile +from ..utils.curves import bezier +from ..error import PatternError def test_remove_duplicate_vertices() -> None: @@ -91,6 +94,19 @@ def test_apply_transforms_advanced() -> None: assert_allclose(combined[0], [0, 10, pi / 2, 1, 1], atol=1e-10) +def test_bezier_validates_weight_length() -> None: + with pytest.raises(PatternError, match='one entry per control point'): + bezier([[0, 0], [1, 1]], [0, 0.5, 1], weights=[1]) + + with pytest.raises(PatternError, match='one entry per control point'): + bezier([[0, 0], [1, 1]], [0, 0.5, 1], weights=[1, 2, 3]) + + +def test_bezier_accepts_exact_weight_count() -> None: + samples = bezier([[0, 0], [1, 1]], [0, 0.5, 1], weights=[1, 2]) + assert_allclose(samples, [[0, 0], [2 / 3, 2 / 3], [1, 1]], atol=1e-10) + + def test_deferred_dict_accessors_resolve_values_once() -> None: calls = 0 diff --git a/masque/utils/curves.py b/masque/utils/curves.py index 2348678..3a7671b 100644 --- a/masque/utils/curves.py +++ b/masque/utils/curves.py @@ -2,6 +2,8 @@ import numpy from numpy.typing import ArrayLike, NDArray from numpy import pi +from ..error import PatternError + try: from numpy import trapezoid except ImportError: @@ -31,6 +33,11 @@ def bezier( tt = numpy.asarray(tt) nn = nodes.shape[0] weights = numpy.ones(nn) if weights is None else numpy.asarray(weights) + if weights.ndim != 1 or weights.shape[0] != nn: + raise PatternError( + f'weights must be a 1D array with one entry per control point; ' + f'got shape {weights.shape} for {nn} control points' + ) with numpy.errstate(divide='ignore'): umul = (tt / (1 - tt)).clip(max=1)