[Grid] enable negative shifts
This commit is contained in:
parent
066ca8f3b8
commit
85ae6e66cd
3 changed files with 74 additions and 23 deletions
|
|
@ -76,6 +76,21 @@ class GridBase(Protocol):
|
|||
el = [0 if p else -1 for p in self.periodic]
|
||||
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
|
||||
|
||||
def _shifted_edge_dxyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
|
||||
if which_shifts is None:
|
||||
return self.dxyz_with_ghost
|
||||
|
||||
shifts = self.shifts[which_shifts, :]
|
||||
edge_dxyz = []
|
||||
for a in range(3):
|
||||
if shifts[a] < 0:
|
||||
ghost = self.dxyz[a][-1] if self.periodic[a] else self.dxyz[a][0]
|
||||
edge_dxyz.append(numpy.hstack((ghost, self.dxyz[a])))
|
||||
else:
|
||||
ghost = self.dxyz[a][0] if self.periodic[a] else self.dxyz[a][-1]
|
||||
edge_dxyz.append(numpy.hstack((self.dxyz[a], ghost)))
|
||||
return edge_dxyz
|
||||
|
||||
@property
|
||||
def center(self) -> NDArray[numpy.float64]:
|
||||
"""
|
||||
|
|
@ -115,15 +130,9 @@ class GridBase(Protocol):
|
|||
"""
|
||||
if which_shifts is None:
|
||||
return self.exyz
|
||||
dxyz = self.dxyz_with_ghost
|
||||
edge_dxyz = self._shifted_edge_dxyz(which_shifts)
|
||||
shifts = self.shifts[which_shifts, :]
|
||||
|
||||
# If shift is negative, use left cell's dx to determine shift
|
||||
for a in range(3):
|
||||
if shifts[a] < 0:
|
||||
dxyz[a] = numpy.roll(dxyz[a], 1)
|
||||
|
||||
return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
|
||||
return [self.exyz[a] + edge_dxyz[a] * shifts[a] for a in range(3)]
|
||||
|
||||
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
|
||||
"""
|
||||
|
|
@ -137,20 +146,7 @@ class GridBase(Protocol):
|
|||
"""
|
||||
if which_shifts is None:
|
||||
return self.dxyz
|
||||
shifts = self.shifts[which_shifts, :]
|
||||
dxyz = self.dxyz_with_ghost
|
||||
|
||||
# If shift is negative, use left cell's dx to determine size
|
||||
sdxyz = []
|
||||
for a in range(3):
|
||||
if shifts[a] < 0:
|
||||
roll_dxyz = numpy.roll(dxyz[a], 1)
|
||||
abs_shift = numpy.abs(shifts[a])
|
||||
sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
|
||||
else:
|
||||
sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
|
||||
|
||||
return sdxyz
|
||||
return [numpy.diff(exyz) for exyz in self.shifted_exyz(which_shifts)]
|
||||
|
||||
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class GridReadMixin(GridPosMixin):
|
|||
surface = numpy.delete(range(3), plane.axis)
|
||||
|
||||
# Extract indices and weights of planes
|
||||
center3 = numpy.insert([0, 0], plane.axis, (plane.pos,))
|
||||
center3 = numpy.insert([0.0, 0.0], plane.axis, (plane.pos,))
|
||||
center_index = self.pos2ind(center3, which_shifts,
|
||||
round_ind=False, check_bounds=False)[plane.axis]
|
||||
centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
|
||||
|
|
|
|||
|
|
@ -309,3 +309,58 @@ def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.
|
|||
assert_allclose(captured['zs'], [1.5, 1.5, 1.5])
|
||||
|
||||
pyplot.close(fig)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_negative_shift_nonperiodic_edges_and_widths() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||
|
||||
assert_allclose(grid.shifted_exyz(0)[0], [-0.5, 0.5, 2.0])
|
||||
assert_allclose(grid.shifted_dxyz(0)[0], [1.0, 1.5])
|
||||
assert_allclose(grid.shifted_xyz(0)[0], [0.0, 1.25])
|
||||
|
||||
|
||||
def test_negative_shift_periodic_edges_and_widths() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[True, False, False])
|
||||
|
||||
assert_allclose(grid.shifted_exyz(0)[0], [-1.0, 0.5, 2.0])
|
||||
assert_allclose(grid.shifted_dxyz(0)[0], [1.5, 1.5])
|
||||
|
||||
|
||||
def test_negative_shift_coordinate_round_trip() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||
|
||||
ind = grid.pos2ind([1.25, 1.0, 0.5], 0, round_ind=False)
|
||||
pos = grid.ind2pos(ind, 0, round_ind=False)
|
||||
|
||||
assert_allclose(ind, [1.0, 0.0, 0.0])
|
||||
assert_allclose(pos, [1.25, 1.0, 0.5])
|
||||
|
||||
|
||||
def test_negative_shift_draw_cuboid_fractional_fill() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 1], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||
arr = grid.allocate(0)
|
||||
|
||||
grid.draw_cuboid(
|
||||
arr,
|
||||
x=dict(min=0, max=1),
|
||||
y=dict(min=0, max=1),
|
||||
z=dict(min=0, max=1),
|
||||
foreground=1,
|
||||
)
|
||||
|
||||
assert_allclose(arr[0, :, 0, 0], [0.5, 1 / 3])
|
||||
|
||||
|
||||
def test_negative_shift_get_slice_uses_shifted_centers() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 1, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||
cell_data = numpy.zeros(grid.cell_data_shape)
|
||||
cell_data[0, 1, :, 0] = [7, 9]
|
||||
x_center = float(grid.shifted_xyz(0)[0][1])
|
||||
|
||||
grid_slice = grid.get_slice(cell_data, Plane(x=x_center), which_shifts=0)
|
||||
|
||||
assert_allclose(grid_slice, [7, 9])
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue