Remove is_scalar() in favor of numpy.size(...)==1
This commit is contained in:
parent
551da07f3e
commit
ff5ffb2f40
@ -1,8 +0,0 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
def is_scalar(var: Any) -> bool:
|
|
||||||
"""
|
|
||||||
Alias for `not hasattr(var, "__len__")`
|
|
||||||
"""
|
|
||||||
return not hasattr(var, "__len__")
|
|
@ -6,7 +6,6 @@ from typing import List, Optional, Union, Sequence, Callable
|
|||||||
import numpy # type: ignore
|
import numpy # type: ignore
|
||||||
from float_raster import raster
|
from float_raster import raster
|
||||||
|
|
||||||
from ._helpers import is_scalar
|
|
||||||
from . import GridError
|
from . import GridError
|
||||||
|
|
||||||
|
|
||||||
@ -58,8 +57,8 @@ def draw_polygons(self,
|
|||||||
if not polygon.shape[0] > 2:
|
if not polygon.shape[0] > 2:
|
||||||
raise GridError(malformed + 'must consist of more than 2 points')
|
raise GridError(malformed + 'must consist of more than 2 points')
|
||||||
if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1:
|
if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1:
|
||||||
raise GridError(malformed + 'must be in plane with surface normal %s'
|
raise GridError(malformed + 'must be in plane with surface normal '
|
||||||
% 'xyz'[surface_normal])
|
+ 'xyz'[surface_normal])
|
||||||
|
|
||||||
# Broadcast eps where necessary
|
# Broadcast eps where necessary
|
||||||
if numpy.size(eps) == 1:
|
if numpy.size(eps) == 1:
|
||||||
@ -108,8 +107,8 @@ def draw_polygons(self,
|
|||||||
eps_i = eps[i](x0, y0, z0)
|
eps_i = eps[i](x0, y0, z0)
|
||||||
if not numpy.isfinite(eps_i).all():
|
if not numpy.isfinite(eps_i).all():
|
||||||
raise GridError('Non-finite values in eps[%u]' % i)
|
raise GridError('Non-finite values in eps[%u]' % i)
|
||||||
elif not is_scalar(eps[i]):
|
|
||||||
raise GridError('Unsupported eps[{}]: {}'.format(i, type(eps[i])))
|
raise GridError('Unsupported eps[{}]: {}'.format(i, type(eps[i])))
|
||||||
|
elif numpy.size(eps[i]) != 1:
|
||||||
else:
|
else:
|
||||||
# eps[i] is scalar non-callable
|
# eps[i] is scalar non-callable
|
||||||
eps_i = eps[i]
|
eps_i = eps[i]
|
||||||
@ -228,7 +227,7 @@ def draw_slab(self,
|
|||||||
if surface_normal not in range(3):
|
if surface_normal not in range(3):
|
||||||
raise GridError('Invalid surface_normal direction')
|
raise GridError('Invalid surface_normal direction')
|
||||||
|
|
||||||
if not is_scalar(center):
|
if numpy.size(center) != 1:
|
||||||
center = numpy.squeeze(center)
|
center = numpy.squeeze(center)
|
||||||
if len(center) == 3:
|
if len(center) == 3:
|
||||||
center = center[surface_normal]
|
center = center[surface_normal]
|
||||||
|
@ -7,7 +7,6 @@ import pickle
|
|||||||
import warnings
|
import warnings
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
from ._helpers import is_scalar
|
|
||||||
from . import GridError
|
from . import GridError
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,7 +5,6 @@ from typing import Dict, Optional, Union, Any
|
|||||||
|
|
||||||
import numpy # type: ignore
|
import numpy # type: ignore
|
||||||
|
|
||||||
from ._helpers import is_scalar
|
|
||||||
from . import GridError
|
from . import GridError
|
||||||
|
|
||||||
# .visualize_* uses matplotlib
|
# .visualize_* uses matplotlib
|
||||||
@ -34,14 +33,14 @@ def get_slice(self,
|
|||||||
Returns:
|
Returns:
|
||||||
Array containing the portion of the grid.
|
Array containing the portion of the grid.
|
||||||
"""
|
"""
|
||||||
if not is_scalar(center) and numpy.isreal(center):
|
if numpy.size(center) != 1 or not numpy.isreal(center):
|
||||||
raise GridError('center must be a real scalar')
|
raise GridError('center must be a real scalar')
|
||||||
|
|
||||||
sp = round(sample_period)
|
sp = round(sample_period)
|
||||||
if sp <= 0:
|
if sp <= 0:
|
||||||
raise GridError('sample_period must be positive')
|
raise GridError('sample_period must be positive')
|
||||||
|
|
||||||
if not is_scalar(which_shifts) or which_shifts < 0:
|
if numpy.size(which_shifts) != 1 or which_shifts < 0:
|
||||||
raise GridError('Invalid which_shifts')
|
raise GridError('Invalid which_shifts')
|
||||||
|
|
||||||
if surface_normal not in range(3):
|
if surface_normal not in range(3):
|
||||||
|
Loading…
Reference in New Issue
Block a user