diff --git a/gridlock/base.py b/gridlock/base.py index aca9c69..e68d955 100644 --- a/gridlock/base.py +++ b/gridlock/base.py @@ -76,6 +76,21 @@ class GridBase(Protocol): el = [0 if p else -1 for p in self.periodic] return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)] + def _shifted_edge_dxyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: + if which_shifts is None: + return self.dxyz_with_ghost + + shifts = self.shifts[which_shifts, :] + edge_dxyz = [] + for a in range(3): + if shifts[a] < 0: + ghost = self.dxyz[a][-1] if self.periodic[a] else self.dxyz[a][0] + edge_dxyz.append(numpy.hstack((ghost, self.dxyz[a]))) + else: + ghost = self.dxyz[a][0] if self.periodic[a] else self.dxyz[a][-1] + edge_dxyz.append(numpy.hstack((self.dxyz[a], ghost))) + return edge_dxyz + @property def center(self) -> NDArray[numpy.float64]: """ @@ -115,15 +130,9 @@ class GridBase(Protocol): """ if which_shifts is None: return self.exyz - dxyz = self.dxyz_with_ghost + edge_dxyz = self._shifted_edge_dxyz(which_shifts) shifts = self.shifts[which_shifts, :] - - # If shift is negative, use left cell's dx to determine shift - for a in range(3): - if shifts[a] < 0: - dxyz[a] = numpy.roll(dxyz[a], 1) - - return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)] + return [self.exyz[a] + edge_dxyz[a] * shifts[a] for a in range(3)] def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]: """ @@ -137,20 +146,7 @@ class GridBase(Protocol): """ if which_shifts is None: return self.dxyz - shifts = self.shifts[which_shifts, :] - dxyz = self.dxyz_with_ghost - - # If shift is negative, use left cell's dx to determine size - sdxyz = [] - for a in range(3): - if shifts[a] < 0: - roll_dxyz = numpy.roll(dxyz[a], 1) - abs_shift = numpy.abs(shifts[a]) - sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift)) - else: - sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a]) - - return sdxyz + return [numpy.diff(exyz) for exyz in self.shifted_exyz(which_shifts)] def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: """ diff --git a/gridlock/read.py b/gridlock/read.py index 9be52b1..f8a40a1 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -73,7 +73,7 @@ class GridReadMixin(GridPosMixin): surface = numpy.delete(range(3), plane.axis) # Extract indices and weights of planes - center3 = numpy.insert([0, 0], plane.axis, (plane.pos,)) + center3 = numpy.insert([0.0, 0.0], plane.axis, (plane.pos,)) center_index = self.pos2ind(center3, which_shifts, round_ind=False, check_bounds=False)[plane.axis] centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index e7b3b28..b4929a4 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -309,3 +309,58 @@ def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest. assert_allclose(captured['zs'], [1.5, 1.5, 1.5]) pyplot.close(fig) + + + + +def test_negative_shift_nonperiodic_edges_and_widths() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) + + assert_allclose(grid.shifted_exyz(0)[0], [-0.5, 0.5, 2.0]) + assert_allclose(grid.shifted_dxyz(0)[0], [1.0, 1.5]) + assert_allclose(grid.shifted_xyz(0)[0], [0.0, 1.25]) + + +def test_negative_shift_periodic_edges_and_widths() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[True, False, False]) + + assert_allclose(grid.shifted_exyz(0)[0], [-1.0, 0.5, 2.0]) + assert_allclose(grid.shifted_dxyz(0)[0], [1.5, 1.5]) + + +def test_negative_shift_coordinate_round_trip() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) + + ind = grid.pos2ind([1.25, 1.0, 0.5], 0, round_ind=False) + pos = grid.ind2pos(ind, 0, round_ind=False) + + assert_allclose(ind, [1.0, 0.0, 0.0]) + assert_allclose(pos, [1.25, 1.0, 0.5]) + + +def test_negative_shift_draw_cuboid_fractional_fill() -> None: + grid = Grid([[0, 1, 3], [0, 1], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) + arr = grid.allocate(0) + + grid.draw_cuboid( + arr, + x=dict(min=0, max=1), + y=dict(min=0, max=1), + z=dict(min=0, max=1), + foreground=1, + ) + + assert_allclose(arr[0, :, 0, 0], [0.5, 1 / 3]) + + +def test_negative_shift_get_slice_uses_shifted_centers() -> None: + grid = Grid([[0, 1, 3], [0, 1, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) + cell_data = numpy.zeros(grid.cell_data_shape) + cell_data[0, 1, :, 0] = [7, 9] + x_center = float(grid.shifted_xyz(0)[0][1]) + + grid_slice = grid.get_slice(cell_data, Plane(x=x_center), which_shifts=0) + + assert_allclose(grid_slice, [7, 9]) + +