improve arg checking
This commit is contained in:
parent
526b9e1666
commit
15c2cf8351
3 changed files with 90 additions and 40 deletions
|
|
@ -95,6 +95,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
|||
`GridError` on invalid input
|
||||
"""
|
||||
edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates]
|
||||
if len(edge_arrs) != 3:
|
||||
raise GridError('pixel_edge_coordinates must contain exactly 3 coordinate arrays')
|
||||
self.exyz = [numpy.unique(edges) for edges in edge_arrs]
|
||||
self.shifts = numpy.array(shifts, dtype=float)
|
||||
|
||||
|
|
@ -106,6 +108,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
|||
self.periodic = [periodic] * 3
|
||||
else:
|
||||
self.periodic = list(periodic)
|
||||
if len(self.periodic) != 3:
|
||||
raise GridError('periodic must be a bool or a sequence of length 3')
|
||||
|
||||
if len(self.shifts.shape) != 2:
|
||||
raise GridError('Misshapen shifts: shifts must have two axes! '
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
import numpy
|
||||
from numpy.testing import assert_allclose #, assert_array_equal
|
||||
|
||||
from .. import Grid, Extent, GridError, Plane
|
||||
from .. import Grid, Extent, GridError, Plane, Slab
|
||||
|
||||
|
||||
def test_draw_oncenter_2x2() -> None:
|
||||
|
|
@ -194,3 +194,35 @@ def test_sampled_visualization_helpers_do_not_error() -> None:
|
|||
|
||||
pyplot.close(fig_slice)
|
||||
pyplot.close(fig_edges)
|
||||
|
||||
|
||||
def test_grid_constructor_rejects_invalid_coordinate_count() -> None:
|
||||
with pytest.raises(GridError):
|
||||
Grid([[0, 1], [0, 1]], shifts=[[0, 0, 0]])
|
||||
|
||||
with pytest.raises(GridError):
|
||||
Grid([[0, 1], [0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
|
||||
|
||||
|
||||
def test_grid_constructor_rejects_invalid_periodic_length() -> None:
|
||||
with pytest.raises(GridError):
|
||||
Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]], periodic=[True, False])
|
||||
|
||||
|
||||
def test_extent_and_slab_reject_inverted_geometry() -> None:
|
||||
with pytest.raises(GridError):
|
||||
Extent(center=0, min=1)
|
||||
|
||||
with pytest.raises(GridError):
|
||||
Extent(min=2, max=1)
|
||||
|
||||
with pytest.raises(GridError):
|
||||
Slab(axis='z', center=1, max=0)
|
||||
|
||||
|
||||
def test_extent_accepts_scalar_like_inputs() -> None:
|
||||
extent = Extent(min=numpy.array([1.0]), span=numpy.array([4.0]))
|
||||
|
||||
assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0])
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,25 @@
|
|||
from typing import Protocol, TypedDict, runtime_checkable, cast
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy
|
||||
|
||||
|
||||
class GridError(Exception):
|
||||
""" Base error type for `gridlock` """
|
||||
pass
|
||||
|
||||
|
||||
def _coerce_scalar(name: str, value: object) -> float:
|
||||
arr = numpy.asarray(value)
|
||||
if arr.size != 1:
|
||||
raise GridError(f'{name} must be a scalar value')
|
||||
|
||||
try:
|
||||
return float(arr.reshape(()))
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise GridError(f'{name} must be a real scalar value') from exc
|
||||
|
||||
|
||||
class ExtentDict(TypedDict, total=False):
|
||||
"""
|
||||
Geometrical definition of an extent (1D bounded region)
|
||||
|
|
@ -58,44 +71,46 @@ class Extent(ExtentProtocol):
|
|||
max: float | None = None,
|
||||
span: float | None = None,
|
||||
) -> None:
|
||||
if sum(cc is None for cc in (min, center, max, span)) != 2:
|
||||
raise GridError('Exactly two of min, center, max, span must be None!')
|
||||
values = {
|
||||
'min': None if min is None else _coerce_scalar('min', min),
|
||||
'center': None if center is None else _coerce_scalar('center', center),
|
||||
'max': None if max is None else _coerce_scalar('max', max),
|
||||
'span': None if span is None else _coerce_scalar('span', span),
|
||||
}
|
||||
if sum(value is not None for value in values.values()) != 2:
|
||||
raise GridError('Exactly two of min, center, max, span must be provided')
|
||||
|
||||
if span is None:
|
||||
if center is None:
|
||||
assert min is not None
|
||||
assert max is not None
|
||||
assert max >= min
|
||||
center = 0.5 * (max + min)
|
||||
span = max - min
|
||||
elif max is None:
|
||||
assert min is not None
|
||||
assert center is not None
|
||||
span = 2 * (center - min)
|
||||
elif min is None:
|
||||
assert center is not None
|
||||
assert max is not None
|
||||
span = 2 * (max - center)
|
||||
else: # noqa: PLR5501
|
||||
if center is not None:
|
||||
pass
|
||||
elif max is None:
|
||||
assert min is not None
|
||||
assert span is not None
|
||||
center = min + 0.5 * span
|
||||
elif min is None:
|
||||
assert max is not None
|
||||
assert span is not None
|
||||
center = max - 0.5 * span
|
||||
min_v = values['min']
|
||||
center_v = values['center']
|
||||
max_v = values['max']
|
||||
span_v = values['span']
|
||||
|
||||
assert center is not None
|
||||
assert span is not None
|
||||
if hasattr(center, '__len__'):
|
||||
assert len(center) == 1
|
||||
if hasattr(span, '__len__'):
|
||||
assert len(span) == 1
|
||||
self.center = center
|
||||
self.span = span
|
||||
if span_v is not None and span_v < 0:
|
||||
raise GridError('span must be non-negative')
|
||||
|
||||
if min_v is not None and max_v is not None:
|
||||
if max_v < min_v:
|
||||
raise GridError('max must be greater than or equal to min')
|
||||
center_v = 0.5 * (max_v + min_v)
|
||||
span_v = max_v - min_v
|
||||
elif center_v is not None and min_v is not None:
|
||||
span_v = 2 * (center_v - min_v)
|
||||
if span_v < 0:
|
||||
raise GridError('min must be less than or equal to center')
|
||||
elif center_v is not None and max_v is not None:
|
||||
span_v = 2 * (max_v - center_v)
|
||||
if span_v < 0:
|
||||
raise GridError('center must be less than or equal to max')
|
||||
elif min_v is not None and span_v is not None:
|
||||
center_v = min_v + 0.5 * span_v
|
||||
elif max_v is not None and span_v is not None:
|
||||
center_v = max_v - 0.5 * span_v
|
||||
|
||||
if center_v is None or span_v is None:
|
||||
raise GridError('Unable to construct extent from the provided values')
|
||||
|
||||
self.center = center_v
|
||||
self.span = span_v
|
||||
|
||||
|
||||
class SlabDict(TypedDict, total=False):
|
||||
|
|
@ -231,4 +246,3 @@ class Plane(PlaneProtocol):
|
|||
if hasattr(cpos, '__len__'):
|
||||
assert len(cpos) == 1
|
||||
self.pos = cpos
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue