improve arg checking
This commit is contained in:
parent
526b9e1666
commit
15c2cf8351
3 changed files with 90 additions and 40 deletions
|
|
@ -95,6 +95,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
||||||
`GridError` on invalid input
|
`GridError` on invalid input
|
||||||
"""
|
"""
|
||||||
edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates]
|
edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates]
|
||||||
|
if len(edge_arrs) != 3:
|
||||||
|
raise GridError('pixel_edge_coordinates must contain exactly 3 coordinate arrays')
|
||||||
self.exyz = [numpy.unique(edges) for edges in edge_arrs]
|
self.exyz = [numpy.unique(edges) for edges in edge_arrs]
|
||||||
self.shifts = numpy.array(shifts, dtype=float)
|
self.shifts = numpy.array(shifts, dtype=float)
|
||||||
|
|
||||||
|
|
@ -106,6 +108,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
||||||
self.periodic = [periodic] * 3
|
self.periodic = [periodic] * 3
|
||||||
else:
|
else:
|
||||||
self.periodic = list(periodic)
|
self.periodic = list(periodic)
|
||||||
|
if len(self.periodic) != 3:
|
||||||
|
raise GridError('periodic must be a bool or a sequence of length 3')
|
||||||
|
|
||||||
if len(self.shifts.shape) != 2:
|
if len(self.shifts.shape) != 2:
|
||||||
raise GridError('Misshapen shifts: shifts must have two axes! '
|
raise GridError('Misshapen shifts: shifts must have two axes! '
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import pytest
|
||||||
import numpy
|
import numpy
|
||||||
from numpy.testing import assert_allclose #, assert_array_equal
|
from numpy.testing import assert_allclose #, assert_array_equal
|
||||||
|
|
||||||
from .. import Grid, Extent, GridError, Plane
|
from .. import Grid, Extent, GridError, Plane, Slab
|
||||||
|
|
||||||
|
|
||||||
def test_draw_oncenter_2x2() -> None:
|
def test_draw_oncenter_2x2() -> None:
|
||||||
|
|
@ -194,3 +194,35 @@ def test_sampled_visualization_helpers_do_not_error() -> None:
|
||||||
|
|
||||||
pyplot.close(fig_slice)
|
pyplot.close(fig_slice)
|
||||||
pyplot.close(fig_edges)
|
pyplot.close(fig_edges)
|
||||||
|
|
||||||
|
|
||||||
|
def test_grid_constructor_rejects_invalid_coordinate_count() -> None:
|
||||||
|
with pytest.raises(GridError):
|
||||||
|
Grid([[0, 1], [0, 1]], shifts=[[0, 0, 0]])
|
||||||
|
|
||||||
|
with pytest.raises(GridError):
|
||||||
|
Grid([[0, 1], [0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
|
||||||
|
|
||||||
|
|
||||||
|
def test_grid_constructor_rejects_invalid_periodic_length() -> None:
|
||||||
|
with pytest.raises(GridError):
|
||||||
|
Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]], periodic=[True, False])
|
||||||
|
|
||||||
|
|
||||||
|
def test_extent_and_slab_reject_inverted_geometry() -> None:
|
||||||
|
with pytest.raises(GridError):
|
||||||
|
Extent(center=0, min=1)
|
||||||
|
|
||||||
|
with pytest.raises(GridError):
|
||||||
|
Extent(min=2, max=1)
|
||||||
|
|
||||||
|
with pytest.raises(GridError):
|
||||||
|
Slab(axis='z', center=1, max=0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extent_accepts_scalar_like_inputs() -> None:
|
||||||
|
extent = Extent(min=numpy.array([1.0]), span=numpy.array([4.0]))
|
||||||
|
|
||||||
|
assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,25 @@
|
||||||
from typing import Protocol, TypedDict, runtime_checkable, cast
|
from typing import Protocol, TypedDict, runtime_checkable, cast
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
|
||||||
class GridError(Exception):
|
class GridError(Exception):
|
||||||
""" Base error type for `gridlock` """
|
""" Base error type for `gridlock` """
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_scalar(name: str, value: object) -> float:
|
||||||
|
arr = numpy.asarray(value)
|
||||||
|
if arr.size != 1:
|
||||||
|
raise GridError(f'{name} must be a scalar value')
|
||||||
|
|
||||||
|
try:
|
||||||
|
return float(arr.reshape(()))
|
||||||
|
except (TypeError, ValueError) as exc:
|
||||||
|
raise GridError(f'{name} must be a real scalar value') from exc
|
||||||
|
|
||||||
|
|
||||||
class ExtentDict(TypedDict, total=False):
|
class ExtentDict(TypedDict, total=False):
|
||||||
"""
|
"""
|
||||||
Geometrical definition of an extent (1D bounded region)
|
Geometrical definition of an extent (1D bounded region)
|
||||||
|
|
@ -58,44 +71,46 @@ class Extent(ExtentProtocol):
|
||||||
max: float | None = None,
|
max: float | None = None,
|
||||||
span: float | None = None,
|
span: float | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if sum(cc is None for cc in (min, center, max, span)) != 2:
|
values = {
|
||||||
raise GridError('Exactly two of min, center, max, span must be None!')
|
'min': None if min is None else _coerce_scalar('min', min),
|
||||||
|
'center': None if center is None else _coerce_scalar('center', center),
|
||||||
|
'max': None if max is None else _coerce_scalar('max', max),
|
||||||
|
'span': None if span is None else _coerce_scalar('span', span),
|
||||||
|
}
|
||||||
|
if sum(value is not None for value in values.values()) != 2:
|
||||||
|
raise GridError('Exactly two of min, center, max, span must be provided')
|
||||||
|
|
||||||
if span is None:
|
min_v = values['min']
|
||||||
if center is None:
|
center_v = values['center']
|
||||||
assert min is not None
|
max_v = values['max']
|
||||||
assert max is not None
|
span_v = values['span']
|
||||||
assert max >= min
|
|
||||||
center = 0.5 * (max + min)
|
|
||||||
span = max - min
|
|
||||||
elif max is None:
|
|
||||||
assert min is not None
|
|
||||||
assert center is not None
|
|
||||||
span = 2 * (center - min)
|
|
||||||
elif min is None:
|
|
||||||
assert center is not None
|
|
||||||
assert max is not None
|
|
||||||
span = 2 * (max - center)
|
|
||||||
else: # noqa: PLR5501
|
|
||||||
if center is not None:
|
|
||||||
pass
|
|
||||||
elif max is None:
|
|
||||||
assert min is not None
|
|
||||||
assert span is not None
|
|
||||||
center = min + 0.5 * span
|
|
||||||
elif min is None:
|
|
||||||
assert max is not None
|
|
||||||
assert span is not None
|
|
||||||
center = max - 0.5 * span
|
|
||||||
|
|
||||||
assert center is not None
|
if span_v is not None and span_v < 0:
|
||||||
assert span is not None
|
raise GridError('span must be non-negative')
|
||||||
if hasattr(center, '__len__'):
|
|
||||||
assert len(center) == 1
|
if min_v is not None and max_v is not None:
|
||||||
if hasattr(span, '__len__'):
|
if max_v < min_v:
|
||||||
assert len(span) == 1
|
raise GridError('max must be greater than or equal to min')
|
||||||
self.center = center
|
center_v = 0.5 * (max_v + min_v)
|
||||||
self.span = span
|
span_v = max_v - min_v
|
||||||
|
elif center_v is not None and min_v is not None:
|
||||||
|
span_v = 2 * (center_v - min_v)
|
||||||
|
if span_v < 0:
|
||||||
|
raise GridError('min must be less than or equal to center')
|
||||||
|
elif center_v is not None and max_v is not None:
|
||||||
|
span_v = 2 * (max_v - center_v)
|
||||||
|
if span_v < 0:
|
||||||
|
raise GridError('center must be less than or equal to max')
|
||||||
|
elif min_v is not None and span_v is not None:
|
||||||
|
center_v = min_v + 0.5 * span_v
|
||||||
|
elif max_v is not None and span_v is not None:
|
||||||
|
center_v = max_v - 0.5 * span_v
|
||||||
|
|
||||||
|
if center_v is None or span_v is None:
|
||||||
|
raise GridError('Unable to construct extent from the provided values')
|
||||||
|
|
||||||
|
self.center = center_v
|
||||||
|
self.span = span_v
|
||||||
|
|
||||||
|
|
||||||
class SlabDict(TypedDict, total=False):
|
class SlabDict(TypedDict, total=False):
|
||||||
|
|
@ -231,4 +246,3 @@ class Plane(PlaneProtocol):
|
||||||
if hasattr(cpos, '__len__'):
|
if hasattr(cpos, '__len__'):
|
||||||
assert len(cpos) == 1
|
assert len(cpos) == 1
|
||||||
self.pos = cpos
|
self.pos = cpos
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue