Remove is_scalar() in favor of numpy.size(...)==1

cell_data
Jan Petykiewicz 3 years ago
parent 551da07f3e
commit ff5ffb2f40

@ -1,8 +0,0 @@
from typing import Any
def is_scalar(var: Any) -> bool:
"""
Alias for `not hasattr(var, "__len__")`
"""
return not hasattr(var, "__len__")

@ -6,7 +6,6 @@ from typing import List, Optional, Union, Sequence, Callable
import numpy # type: ignore
from float_raster import raster
from ._helpers import is_scalar
from . import GridError
@ -58,8 +57,8 @@ def draw_polygons(self,
if not polygon.shape[0] > 2:
raise GridError(malformed + 'must consist of more than 2 points')
if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1:
raise GridError(malformed + 'must be in plane with surface normal %s'
% 'xyz'[surface_normal])
raise GridError(malformed + 'must be in plane with surface normal '
+ 'xyz'[surface_normal])
# Broadcast eps where necessary
if numpy.size(eps) == 1:
@ -108,8 +107,8 @@ def draw_polygons(self,
eps_i = eps[i](x0, y0, z0)
if not numpy.isfinite(eps_i).all():
raise GridError('Non-finite values in eps[%u]' % i)
elif not is_scalar(eps[i]):
raise GridError('Unsupported eps[{}]: {}'.format(i, type(eps[i])))
elif numpy.size(eps[i]) != 1:
else:
# eps[i] is scalar non-callable
eps_i = eps[i]
@ -228,7 +227,7 @@ def draw_slab(self,
if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction')
if not is_scalar(center):
if numpy.size(center) != 1:
center = numpy.squeeze(center)
if len(center) == 3:
center = center[surface_normal]

@ -7,7 +7,6 @@ import pickle
import warnings
import copy
from ._helpers import is_scalar
from . import GridError

@ -5,7 +5,6 @@ from typing import Dict, Optional, Union, Any
import numpy # type: ignore
from ._helpers import is_scalar
from . import GridError
# .visualize_* uses matplotlib
@ -34,14 +33,14 @@ def get_slice(self,
Returns:
Array containing the portion of the grid.
"""
if not is_scalar(center) and numpy.isreal(center):
if numpy.size(center) != 1 or not numpy.isreal(center):
raise GridError('center must be a real scalar')
sp = round(sample_period)
if sp <= 0:
raise GridError('sample_period must be positive')
if not is_scalar(which_shifts) or which_shifts < 0:
if numpy.size(which_shifts) != 1 or which_shifts < 0:
raise GridError('Invalid which_shifts')
if surface_normal not in range(3):

Loading…
Cancel
Save