improve arg checking

This commit is contained in:
Jan Petykiewicz 2026-04-20 10:47:35 -07:00
commit 15c2cf8351
3 changed files with 90 additions and 40 deletions

View file

@ -95,6 +95,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
`GridError` on invalid input
"""
edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates]
if len(edge_arrs) != 3:
raise GridError('pixel_edge_coordinates must contain exactly 3 coordinate arrays')
self.exyz = [numpy.unique(edges) for edges in edge_arrs]
self.shifts = numpy.array(shifts, dtype=float)
@ -106,6 +108,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
self.periodic = [periodic] * 3
else:
self.periodic = list(periodic)
if len(self.periodic) != 3:
raise GridError('periodic must be a bool or a sequence of length 3')
if len(self.shifts.shape) != 2:
raise GridError('Misshapen shifts: shifts must have two axes! '

View file

@ -2,7 +2,7 @@ import pytest
import numpy
from numpy.testing import assert_allclose #, assert_array_equal
from .. import Grid, Extent, GridError, Plane
from .. import Grid, Extent, GridError, Plane, Slab
def test_draw_oncenter_2x2() -> None:
@ -194,3 +194,35 @@ def test_sampled_visualization_helpers_do_not_error() -> None:
pyplot.close(fig_slice)
pyplot.close(fig_edges)
def test_grid_constructor_rejects_invalid_coordinate_count() -> None:
with pytest.raises(GridError):
Grid([[0, 1], [0, 1]], shifts=[[0, 0, 0]])
with pytest.raises(GridError):
Grid([[0, 1], [0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
def test_grid_constructor_rejects_invalid_periodic_length() -> None:
with pytest.raises(GridError):
Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]], periodic=[True, False])
def test_extent_and_slab_reject_inverted_geometry() -> None:
with pytest.raises(GridError):
Extent(center=0, min=1)
with pytest.raises(GridError):
Extent(min=2, max=1)
with pytest.raises(GridError):
Slab(axis='z', center=1, max=0)
def test_extent_accepts_scalar_like_inputs() -> None:
extent = Extent(min=numpy.array([1.0]), span=numpy.array([4.0]))
assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0])

View file

@ -1,12 +1,25 @@
from typing import Protocol, TypedDict, runtime_checkable, cast
from dataclasses import dataclass
import numpy
class GridError(Exception):
""" Base error type for `gridlock` """
pass
def _coerce_scalar(name: str, value: object) -> float:
arr = numpy.asarray(value)
if arr.size != 1:
raise GridError(f'{name} must be a scalar value')
try:
return float(arr.reshape(()))
except (TypeError, ValueError) as exc:
raise GridError(f'{name} must be a real scalar value') from exc
class ExtentDict(TypedDict, total=False):
"""
Geometrical definition of an extent (1D bounded region)
@ -58,44 +71,46 @@ class Extent(ExtentProtocol):
max: float | None = None,
span: float | None = None,
) -> None:
if sum(cc is None for cc in (min, center, max, span)) != 2:
raise GridError('Exactly two of min, center, max, span must be None!')
values = {
'min': None if min is None else _coerce_scalar('min', min),
'center': None if center is None else _coerce_scalar('center', center),
'max': None if max is None else _coerce_scalar('max', max),
'span': None if span is None else _coerce_scalar('span', span),
}
if sum(value is not None for value in values.values()) != 2:
raise GridError('Exactly two of min, center, max, span must be provided')
if span is None:
if center is None:
assert min is not None
assert max is not None
assert max >= min
center = 0.5 * (max + min)
span = max - min
elif max is None:
assert min is not None
assert center is not None
span = 2 * (center - min)
elif min is None:
assert center is not None
assert max is not None
span = 2 * (max - center)
else: # noqa: PLR5501
if center is not None:
pass
elif max is None:
assert min is not None
assert span is not None
center = min + 0.5 * span
elif min is None:
assert max is not None
assert span is not None
center = max - 0.5 * span
min_v = values['min']
center_v = values['center']
max_v = values['max']
span_v = values['span']
assert center is not None
assert span is not None
if hasattr(center, '__len__'):
assert len(center) == 1
if hasattr(span, '__len__'):
assert len(span) == 1
self.center = center
self.span = span
if span_v is not None and span_v < 0:
raise GridError('span must be non-negative')
if min_v is not None and max_v is not None:
if max_v < min_v:
raise GridError('max must be greater than or equal to min')
center_v = 0.5 * (max_v + min_v)
span_v = max_v - min_v
elif center_v is not None and min_v is not None:
span_v = 2 * (center_v - min_v)
if span_v < 0:
raise GridError('min must be less than or equal to center')
elif center_v is not None and max_v is not None:
span_v = 2 * (max_v - center_v)
if span_v < 0:
raise GridError('center must be less than or equal to max')
elif min_v is not None and span_v is not None:
center_v = min_v + 0.5 * span_v
elif max_v is not None and span_v is not None:
center_v = max_v - 0.5 * span_v
if center_v is None or span_v is None:
raise GridError('Unable to construct extent from the provided values')
self.center = center_v
self.span = span_v
class SlabDict(TypedDict, total=False):
@ -231,4 +246,3 @@ class Plane(PlaneProtocol):
if hasattr(cpos, '__len__'):
assert len(cpos) == 1
self.pos = cpos