[Grid] enable negative shifts
This commit is contained in:
parent
066ca8f3b8
commit
85ae6e66cd
3 changed files with 74 additions and 23 deletions
|
|
@ -76,6 +76,21 @@ class GridBase(Protocol):
|
||||||
el = [0 if p else -1 for p in self.periodic]
|
el = [0 if p else -1 for p in self.periodic]
|
||||||
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
|
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
|
||||||
|
|
||||||
|
def _shifted_edge_dxyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
|
||||||
|
if which_shifts is None:
|
||||||
|
return self.dxyz_with_ghost
|
||||||
|
|
||||||
|
shifts = self.shifts[which_shifts, :]
|
||||||
|
edge_dxyz = []
|
||||||
|
for a in range(3):
|
||||||
|
if shifts[a] < 0:
|
||||||
|
ghost = self.dxyz[a][-1] if self.periodic[a] else self.dxyz[a][0]
|
||||||
|
edge_dxyz.append(numpy.hstack((ghost, self.dxyz[a])))
|
||||||
|
else:
|
||||||
|
ghost = self.dxyz[a][0] if self.periodic[a] else self.dxyz[a][-1]
|
||||||
|
edge_dxyz.append(numpy.hstack((self.dxyz[a], ghost)))
|
||||||
|
return edge_dxyz
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def center(self) -> NDArray[numpy.float64]:
|
def center(self) -> NDArray[numpy.float64]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -115,15 +130,9 @@ class GridBase(Protocol):
|
||||||
"""
|
"""
|
||||||
if which_shifts is None:
|
if which_shifts is None:
|
||||||
return self.exyz
|
return self.exyz
|
||||||
dxyz = self.dxyz_with_ghost
|
edge_dxyz = self._shifted_edge_dxyz(which_shifts)
|
||||||
shifts = self.shifts[which_shifts, :]
|
shifts = self.shifts[which_shifts, :]
|
||||||
|
return [self.exyz[a] + edge_dxyz[a] * shifts[a] for a in range(3)]
|
||||||
# If shift is negative, use left cell's dx to determine shift
|
|
||||||
for a in range(3):
|
|
||||||
if shifts[a] < 0:
|
|
||||||
dxyz[a] = numpy.roll(dxyz[a], 1)
|
|
||||||
|
|
||||||
return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
|
|
||||||
|
|
||||||
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
|
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -137,20 +146,7 @@ class GridBase(Protocol):
|
||||||
"""
|
"""
|
||||||
if which_shifts is None:
|
if which_shifts is None:
|
||||||
return self.dxyz
|
return self.dxyz
|
||||||
shifts = self.shifts[which_shifts, :]
|
return [numpy.diff(exyz) for exyz in self.shifted_exyz(which_shifts)]
|
||||||
dxyz = self.dxyz_with_ghost
|
|
||||||
|
|
||||||
# If shift is negative, use left cell's dx to determine size
|
|
||||||
sdxyz = []
|
|
||||||
for a in range(3):
|
|
||||||
if shifts[a] < 0:
|
|
||||||
roll_dxyz = numpy.roll(dxyz[a], 1)
|
|
||||||
abs_shift = numpy.abs(shifts[a])
|
|
||||||
sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
|
|
||||||
else:
|
|
||||||
sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
|
|
||||||
|
|
||||||
return sdxyz
|
|
||||||
|
|
||||||
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
|
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -73,7 +73,7 @@ class GridReadMixin(GridPosMixin):
|
||||||
surface = numpy.delete(range(3), plane.axis)
|
surface = numpy.delete(range(3), plane.axis)
|
||||||
|
|
||||||
# Extract indices and weights of planes
|
# Extract indices and weights of planes
|
||||||
center3 = numpy.insert([0, 0], plane.axis, (plane.pos,))
|
center3 = numpy.insert([0.0, 0.0], plane.axis, (plane.pos,))
|
||||||
center_index = self.pos2ind(center3, which_shifts,
|
center_index = self.pos2ind(center3, which_shifts,
|
||||||
round_ind=False, check_bounds=False)[plane.axis]
|
round_ind=False, check_bounds=False)[plane.axis]
|
||||||
centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
|
centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
|
||||||
|
|
|
||||||
|
|
@ -309,3 +309,58 @@ def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.
|
||||||
assert_allclose(captured['zs'], [1.5, 1.5, 1.5])
|
assert_allclose(captured['zs'], [1.5, 1.5, 1.5])
|
||||||
|
|
||||||
pyplot.close(fig)
|
pyplot.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_negative_shift_nonperiodic_edges_and_widths() -> None:
|
||||||
|
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||||
|
|
||||||
|
assert_allclose(grid.shifted_exyz(0)[0], [-0.5, 0.5, 2.0])
|
||||||
|
assert_allclose(grid.shifted_dxyz(0)[0], [1.0, 1.5])
|
||||||
|
assert_allclose(grid.shifted_xyz(0)[0], [0.0, 1.25])
|
||||||
|
|
||||||
|
|
||||||
|
def test_negative_shift_periodic_edges_and_widths() -> None:
|
||||||
|
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[True, False, False])
|
||||||
|
|
||||||
|
assert_allclose(grid.shifted_exyz(0)[0], [-1.0, 0.5, 2.0])
|
||||||
|
assert_allclose(grid.shifted_dxyz(0)[0], [1.5, 1.5])
|
||||||
|
|
||||||
|
|
||||||
|
def test_negative_shift_coordinate_round_trip() -> None:
|
||||||
|
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||||
|
|
||||||
|
ind = grid.pos2ind([1.25, 1.0, 0.5], 0, round_ind=False)
|
||||||
|
pos = grid.ind2pos(ind, 0, round_ind=False)
|
||||||
|
|
||||||
|
assert_allclose(ind, [1.0, 0.0, 0.0])
|
||||||
|
assert_allclose(pos, [1.25, 1.0, 0.5])
|
||||||
|
|
||||||
|
|
||||||
|
def test_negative_shift_draw_cuboid_fractional_fill() -> None:
|
||||||
|
grid = Grid([[0, 1, 3], [0, 1], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||||
|
arr = grid.allocate(0)
|
||||||
|
|
||||||
|
grid.draw_cuboid(
|
||||||
|
arr,
|
||||||
|
x=dict(min=0, max=1),
|
||||||
|
y=dict(min=0, max=1),
|
||||||
|
z=dict(min=0, max=1),
|
||||||
|
foreground=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_allclose(arr[0, :, 0, 0], [0.5, 1 / 3])
|
||||||
|
|
||||||
|
|
||||||
|
def test_negative_shift_get_slice_uses_shifted_centers() -> None:
|
||||||
|
grid = Grid([[0, 1, 3], [0, 1, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||||
|
cell_data = numpy.zeros(grid.cell_data_shape)
|
||||||
|
cell_data[0, 1, :, 0] = [7, 9]
|
||||||
|
x_center = float(grid.shifted_xyz(0)[0][1])
|
||||||
|
|
||||||
|
grid_slice = grid.get_slice(cell_data, Plane(x=x_center), which_shifts=0)
|
||||||
|
|
||||||
|
assert_allclose(grid_slice, [7, 9])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue