[Grid] enable negative shifts

This commit is contained in:
Jan Petykiewicz 2026-04-21 19:58:57 -07:00
commit 85ae6e66cd
3 changed files with 74 additions and 23 deletions

View file

@ -76,6 +76,21 @@ class GridBase(Protocol):
el = [0 if p else -1 for p in self.periodic]
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
def _shifted_edge_dxyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
if which_shifts is None:
return self.dxyz_with_ghost
shifts = self.shifts[which_shifts, :]
edge_dxyz = []
for a in range(3):
if shifts[a] < 0:
ghost = self.dxyz[a][-1] if self.periodic[a] else self.dxyz[a][0]
edge_dxyz.append(numpy.hstack((ghost, self.dxyz[a])))
else:
ghost = self.dxyz[a][0] if self.periodic[a] else self.dxyz[a][-1]
edge_dxyz.append(numpy.hstack((self.dxyz[a], ghost)))
return edge_dxyz
@property
def center(self) -> NDArray[numpy.float64]:
"""
@ -115,15 +130,9 @@ class GridBase(Protocol):
"""
if which_shifts is None:
return self.exyz
dxyz = self.dxyz_with_ghost
edge_dxyz = self._shifted_edge_dxyz(which_shifts)
shifts = self.shifts[which_shifts, :]
# If shift is negative, use left cell's dx to determine shift
for a in range(3):
if shifts[a] < 0:
dxyz[a] = numpy.roll(dxyz[a], 1)
return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
return [self.exyz[a] + edge_dxyz[a] * shifts[a] for a in range(3)]
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
"""
@ -137,20 +146,7 @@ class GridBase(Protocol):
"""
if which_shifts is None:
return self.dxyz
shifts = self.shifts[which_shifts, :]
dxyz = self.dxyz_with_ghost
# If shift is negative, use left cell's dx to determine size
sdxyz = []
for a in range(3):
if shifts[a] < 0:
roll_dxyz = numpy.roll(dxyz[a], 1)
abs_shift = numpy.abs(shifts[a])
sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
else:
sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
return sdxyz
return [numpy.diff(exyz) for exyz in self.shifted_exyz(which_shifts)]
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
"""

View file

@ -73,7 +73,7 @@ class GridReadMixin(GridPosMixin):
surface = numpy.delete(range(3), plane.axis)
# Extract indices and weights of planes
center3 = numpy.insert([0, 0], plane.axis, (plane.pos,))
center3 = numpy.insert([0.0, 0.0], plane.axis, (plane.pos,))
center_index = self.pos2ind(center3, which_shifts,
round_ind=False, check_bounds=False)[plane.axis]
centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)

View file

@ -309,3 +309,58 @@ def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.
assert_allclose(captured['zs'], [1.5, 1.5, 1.5])
pyplot.close(fig)
def test_negative_shift_nonperiodic_edges_and_widths() -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
assert_allclose(grid.shifted_exyz(0)[0], [-0.5, 0.5, 2.0])
assert_allclose(grid.shifted_dxyz(0)[0], [1.0, 1.5])
assert_allclose(grid.shifted_xyz(0)[0], [0.0, 1.25])
def test_negative_shift_periodic_edges_and_widths() -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[True, False, False])
assert_allclose(grid.shifted_exyz(0)[0], [-1.0, 0.5, 2.0])
assert_allclose(grid.shifted_dxyz(0)[0], [1.5, 1.5])
def test_negative_shift_coordinate_round_trip() -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
ind = grid.pos2ind([1.25, 1.0, 0.5], 0, round_ind=False)
pos = grid.ind2pos(ind, 0, round_ind=False)
assert_allclose(ind, [1.0, 0.0, 0.0])
assert_allclose(pos, [1.25, 1.0, 0.5])
def test_negative_shift_draw_cuboid_fractional_fill() -> None:
grid = Grid([[0, 1, 3], [0, 1], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
arr = grid.allocate(0)
grid.draw_cuboid(
arr,
x=dict(min=0, max=1),
y=dict(min=0, max=1),
z=dict(min=0, max=1),
foreground=1,
)
assert_allclose(arr[0, :, 0, 0], [0.5, 1 / 3])
def test_negative_shift_get_slice_uses_shifted_centers() -> None:
grid = Grid([[0, 1, 3], [0, 1, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
cell_data = numpy.zeros(grid.cell_data_shape)
cell_data[0, 1, :, 0] = [7, 9]
x_center = float(grid.shifted_xyz(0)[0][1])
grid_slice = grid.get_slice(cell_data, Plane(x=x_center), which_shifts=0)
assert_allclose(grid_slice, [7, 9])