From 15c2cf83516a8fe9bce4a2a0603398f2bded0dcc Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:47:35 -0700 Subject: [PATCH] improve arg checking --- gridlock/grid.py | 4 ++ gridlock/test/test_grid.py | 34 ++++++++++++++- gridlock/utils.py | 88 ++++++++++++++++++++++---------------- 3 files changed, 88 insertions(+), 38 deletions(-) diff --git a/gridlock/grid.py b/gridlock/grid.py index 5790dbd..5bed422 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -95,6 +95,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): `GridError` on invalid input """ edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates] + if len(edge_arrs) != 3: + raise GridError('pixel_edge_coordinates must contain exactly 3 coordinate arrays') self.exyz = [numpy.unique(edges) for edges in edge_arrs] self.shifts = numpy.array(shifts, dtype=float) @@ -106,6 +108,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): self.periodic = [periodic] * 3 else: self.periodic = list(periodic) + if len(self.periodic) != 3: + raise GridError('periodic must be a bool or a sequence of length 3') if len(self.shifts.shape) != 2: raise GridError('Misshapen shifts: shifts must have two axes! ' diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 84b0f7b..60929e8 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -2,7 +2,7 @@ import pytest import numpy from numpy.testing import assert_allclose #, assert_array_equal -from .. import Grid, Extent, GridError, Plane +from .. import Grid, Extent, GridError, Plane, Slab def test_draw_oncenter_2x2() -> None: @@ -194,3 +194,35 @@ def test_sampled_visualization_helpers_do_not_error() -> None: pyplot.close(fig_slice) pyplot.close(fig_edges) + + +def test_grid_constructor_rejects_invalid_coordinate_count() -> None: + with pytest.raises(GridError): + Grid([[0, 1], [0, 1]], shifts=[[0, 0, 0]]) + + with pytest.raises(GridError): + Grid([[0, 1], [0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) + + +def test_grid_constructor_rejects_invalid_periodic_length() -> None: + with pytest.raises(GridError): + Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]], periodic=[True, False]) + + +def test_extent_and_slab_reject_inverted_geometry() -> None: + with pytest.raises(GridError): + Extent(center=0, min=1) + + with pytest.raises(GridError): + Extent(min=2, max=1) + + with pytest.raises(GridError): + Slab(axis='z', center=1, max=0) + + +def test_extent_accepts_scalar_like_inputs() -> None: + extent = Extent(min=numpy.array([1.0]), span=numpy.array([4.0])) + + assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0]) + + diff --git a/gridlock/utils.py b/gridlock/utils.py index 8a8f11d..585b999 100644 --- a/gridlock/utils.py +++ b/gridlock/utils.py @@ -1,12 +1,25 @@ from typing import Protocol, TypedDict, runtime_checkable, cast from dataclasses import dataclass +import numpy + class GridError(Exception): """ Base error type for `gridlock` """ pass +def _coerce_scalar(name: str, value: object) -> float: + arr = numpy.asarray(value) + if arr.size != 1: + raise GridError(f'{name} must be a scalar value') + + try: + return float(arr.reshape(())) + except (TypeError, ValueError) as exc: + raise GridError(f'{name} must be a real scalar value') from exc + + class ExtentDict(TypedDict, total=False): """ Geometrical definition of an extent (1D bounded region) @@ -58,44 +71,46 @@ class Extent(ExtentProtocol): max: float | None = None, span: float | None = None, ) -> None: - if sum(cc is None for cc in (min, center, max, span)) != 2: - raise GridError('Exactly two of min, center, max, span must be None!') + values = { + 'min': None if min is None else _coerce_scalar('min', min), + 'center': None if center is None else _coerce_scalar('center', center), + 'max': None if max is None else _coerce_scalar('max', max), + 'span': None if span is None else _coerce_scalar('span', span), + } + if sum(value is not None for value in values.values()) != 2: + raise GridError('Exactly two of min, center, max, span must be provided') - if span is None: - if center is None: - assert min is not None - assert max is not None - assert max >= min - center = 0.5 * (max + min) - span = max - min - elif max is None: - assert min is not None - assert center is not None - span = 2 * (center - min) - elif min is None: - assert center is not None - assert max is not None - span = 2 * (max - center) - else: # noqa: PLR5501 - if center is not None: - pass - elif max is None: - assert min is not None - assert span is not None - center = min + 0.5 * span - elif min is None: - assert max is not None - assert span is not None - center = max - 0.5 * span + min_v = values['min'] + center_v = values['center'] + max_v = values['max'] + span_v = values['span'] - assert center is not None - assert span is not None - if hasattr(center, '__len__'): - assert len(center) == 1 - if hasattr(span, '__len__'): - assert len(span) == 1 - self.center = center - self.span = span + if span_v is not None and span_v < 0: + raise GridError('span must be non-negative') + + if min_v is not None and max_v is not None: + if max_v < min_v: + raise GridError('max must be greater than or equal to min') + center_v = 0.5 * (max_v + min_v) + span_v = max_v - min_v + elif center_v is not None and min_v is not None: + span_v = 2 * (center_v - min_v) + if span_v < 0: + raise GridError('min must be less than or equal to center') + elif center_v is not None and max_v is not None: + span_v = 2 * (max_v - center_v) + if span_v < 0: + raise GridError('center must be less than or equal to max') + elif min_v is not None and span_v is not None: + center_v = min_v + 0.5 * span_v + elif max_v is not None and span_v is not None: + center_v = max_v - 0.5 * span_v + + if center_v is None or span_v is None: + raise GridError('Unable to construct extent from the provided values') + + self.center = center_v + self.span = span_v class SlabDict(TypedDict, total=False): @@ -231,4 +246,3 @@ class Plane(PlaneProtocol): if hasattr(cpos, '__len__'): assert len(cpos) == 1 self.pos = cpos -