From ff5ffb2f40d27e3f2502213d746147fdda4fbc18 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Sun, 24 Oct 2021 19:05:13 -0700 Subject: [PATCH] Remove is_scalar() in favor of numpy.size(...)==1 --- gridlock/_helpers.py | 8 -------- gridlock/draw.py | 9 ++++----- gridlock/grid.py | 1 - gridlock/read.py | 5 ++--- 4 files changed, 6 insertions(+), 17 deletions(-) delete mode 100644 gridlock/_helpers.py diff --git a/gridlock/_helpers.py b/gridlock/_helpers.py deleted file mode 100644 index 5c7f794..0000000 --- a/gridlock/_helpers.py +++ /dev/null @@ -1,8 +0,0 @@ -from typing import Any - - -def is_scalar(var: Any) -> bool: - """ - Alias for `not hasattr(var, "__len__")` - """ - return not hasattr(var, "__len__") diff --git a/gridlock/draw.py b/gridlock/draw.py index 2dd18d7..848e246 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -6,7 +6,6 @@ from typing import List, Optional, Union, Sequence, Callable import numpy # type: ignore from float_raster import raster -from ._helpers import is_scalar from . import GridError @@ -58,8 +57,8 @@ def draw_polygons(self, if not polygon.shape[0] > 2: raise GridError(malformed + 'must consist of more than 2 points') if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1: - raise GridError(malformed + 'must be in plane with surface normal %s' - % 'xyz'[surface_normal]) + raise GridError(malformed + 'must be in plane with surface normal ' + + 'xyz'[surface_normal]) # Broadcast eps where necessary if numpy.size(eps) == 1: @@ -108,8 +107,8 @@ def draw_polygons(self, eps_i = eps[i](x0, y0, z0) if not numpy.isfinite(eps_i).all(): raise GridError('Non-finite values in eps[%u]' % i) - elif not is_scalar(eps[i]): raise GridError('Unsupported eps[{}]: {}'.format(i, type(eps[i]))) + elif numpy.size(eps[i]) != 1: else: # eps[i] is scalar non-callable eps_i = eps[i] @@ -228,7 +227,7 @@ def draw_slab(self, if surface_normal not in range(3): raise GridError('Invalid surface_normal direction') - if not is_scalar(center): + if numpy.size(center) != 1: center = numpy.squeeze(center) if len(center) == 3: center = center[surface_normal] diff --git a/gridlock/grid.py b/gridlock/grid.py index 3386823..f18390d 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -7,7 +7,6 @@ import pickle import warnings import copy -from ._helpers import is_scalar from . import GridError diff --git a/gridlock/read.py b/gridlock/read.py index 472508e..aa059d5 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -5,7 +5,6 @@ from typing import Dict, Optional, Union, Any import numpy # type: ignore -from ._helpers import is_scalar from . import GridError # .visualize_* uses matplotlib @@ -34,14 +33,14 @@ def get_slice(self, Returns: Array containing the portion of the grid. """ - if not is_scalar(center) and numpy.isreal(center): + if numpy.size(center) != 1 or not numpy.isreal(center): raise GridError('center must be a real scalar') sp = round(sample_period) if sp <= 0: raise GridError('sample_period must be positive') - if not is_scalar(which_shifts) or which_shifts < 0: + if numpy.size(which_shifts) != 1 or which_shifts < 0: raise GridError('Invalid which_shifts') if surface_normal not in range(3):