Add shift_with_mirror, and add shift_distance argument to rotation()
This commit is contained in:
parent
35555cf4b3
commit
8f202fd061
@ -215,12 +215,13 @@ def m2j(omega: complex,
|
|||||||
return op
|
return op
|
||||||
|
|
||||||
|
|
||||||
def rotation(axis: int, shape: List[int]) -> sparse.spmatrix:
|
def rotation(axis: int, shape: List[int], shift_distance: int=1) -> sparse.spmatrix:
|
||||||
"""
|
"""
|
||||||
Utility operator for performing a circular shift along a specified axis by 1 element.
|
Utility operator for performing a circular shift along a specified axis by 1 element.
|
||||||
|
|
||||||
:param axis: Axis to shift along. x=0, y=1, z=2
|
:param axis: Axis to shift along. x=0, y=1, z=2
|
||||||
:param shape: Shape of the grid being shifted
|
:param shape: Shape of the grid being shifted
|
||||||
|
:param shift_distance: Number of cells to shift by. May be negative. Default 1.
|
||||||
:return: Sparse matrix for performing the circular shift
|
:return: Sparse matrix for performing the circular shift
|
||||||
"""
|
"""
|
||||||
if len(shape) not in (2, 3):
|
if len(shape) not in (2, 3):
|
||||||
@ -228,12 +229,11 @@ def rotation(axis: int, shape: List[int]) -> sparse.spmatrix:
|
|||||||
if axis not in range(len(shape)):
|
if axis not in range(len(shape)):
|
||||||
raise Exception('Invalid direction: {}, shape is {}'.format(axis, shape))
|
raise Exception('Invalid direction: {}, shape is {}'.format(axis, shape))
|
||||||
|
|
||||||
n = numpy.prod(shape)
|
shifts = [abs(shift_distance) if a == axis else 0 for a in range(3)]
|
||||||
|
|
||||||
shifts = [1 if k == axis else 0 for k in range(3)]
|
|
||||||
shifted_diags = [(numpy.arange(n) + s) % n for n, s in zip(shape, shifts)]
|
shifted_diags = [(numpy.arange(n) + s) % n for n, s in zip(shape, shifts)]
|
||||||
ijk = numpy.meshgrid(*shifted_diags, indexing='ij')
|
ijk = numpy.meshgrid(*shifted_diags, indexing='ij')
|
||||||
|
|
||||||
|
n = numpy.prod(shape)
|
||||||
i_ind = numpy.arange(n)
|
i_ind = numpy.arange(n)
|
||||||
j_ind = ijk[0] + ijk[1] * shape[0]
|
j_ind = ijk[0] + ijk[1] * shape[0]
|
||||||
if len(shape) == 3:
|
if len(shape) == 3:
|
||||||
@ -241,8 +241,52 @@ def rotation(axis: int, shape: List[int]) -> sparse.spmatrix:
|
|||||||
|
|
||||||
vij = (numpy.ones(n), (i_ind, j_ind.flatten(order='F')))
|
vij = (numpy.ones(n), (i_ind, j_ind.flatten(order='F')))
|
||||||
|
|
||||||
D = sparse.csr_matrix(vij, shape=(n, n))
|
d = sparse.csr_matrix(vij, shape=(n, n))
|
||||||
return D
|
|
||||||
|
if shift_distance < 0:
|
||||||
|
d = d.T
|
||||||
|
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def shift_with_mirror(axis: int, shape: List[int], shift_distance: int=1) -> sparse.spmatrix:
|
||||||
|
"""
|
||||||
|
Utility operator for performing an n-element shift along a specified axis, with mirror
|
||||||
|
boundary conditions applied to the cells beyond the receding edge.
|
||||||
|
|
||||||
|
:param axis: Axis to shift along. x=0, y=1, z=2
|
||||||
|
:param shape: Shape of the grid being shifted
|
||||||
|
:param shift_distance: Number of cells to shift by. May be negative. Default 1.
|
||||||
|
:return: Sparse matrix for performing the circular shift
|
||||||
|
"""
|
||||||
|
if len(shape) not in (2, 3):
|
||||||
|
raise Exception('Invalid shape: {}'.format(shape))
|
||||||
|
if axis not in range(len(shape)):
|
||||||
|
raise Exception('Invalid direction: {}, shape is {}'.format(axis, shape))
|
||||||
|
if shift_distance >= shape[axis]:
|
||||||
|
raise Exception('Shift ({}) is too large for axis {} of size {}'.format(
|
||||||
|
shift_distance, axis, shape[axis]))
|
||||||
|
|
||||||
|
def mirrored_range(n, s):
|
||||||
|
v = numpy.arange(n) + s
|
||||||
|
v = numpy.where(v >= n, 2 * n - v - 1, v)
|
||||||
|
v = numpy.where(v < 0, - 1 - v, v)
|
||||||
|
return v
|
||||||
|
|
||||||
|
shifts = [shift_distance if a == axis else 0 for a in range(3)]
|
||||||
|
shifted_diags = [mirrored_range(n, s) for n, s in zip(shape, shifts)]
|
||||||
|
ijk = numpy.meshgrid(*shifted_diags, indexing='ij')
|
||||||
|
|
||||||
|
n = numpy.prod(shape)
|
||||||
|
i_ind = numpy.arange(n)
|
||||||
|
j_ind = ijk[0] + ijk[1] * shape[0]
|
||||||
|
if len(shape) == 3:
|
||||||
|
j_ind += ijk[2] * shape[0] * shape[1]
|
||||||
|
|
||||||
|
vij = (numpy.ones(n), (i_ind, j_ind.flatten(order='F')))
|
||||||
|
|
||||||
|
d = sparse.csr_matrix(vij, shape=(n, n))
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
def deriv_forward(dx_e: List[numpy.ndarray]) -> List[sparse.spmatrix]:
|
def deriv_forward(dx_e: List[numpy.ndarray]) -> List[sparse.spmatrix]:
|
||||||
@ -258,7 +302,7 @@ def deriv_forward(dx_e: List[numpy.ndarray]) -> List[sparse.spmatrix]:
|
|||||||
dx_e_expanded = numpy.meshgrid(*dx_e, indexing='ij')
|
dx_e_expanded = numpy.meshgrid(*dx_e, indexing='ij')
|
||||||
|
|
||||||
def deriv(axis):
|
def deriv(axis):
|
||||||
return rotation(axis, shape) - sparse.eye(n)
|
return rotation(axis, shape, 1) - sparse.eye(n)
|
||||||
|
|
||||||
Ds = [sparse.diags(+1 / dx.flatten(order='F')) @ deriv(a)
|
Ds = [sparse.diags(+1 / dx.flatten(order='F')) @ deriv(a)
|
||||||
for a, dx in enumerate(dx_e_expanded)]
|
for a, dx in enumerate(dx_e_expanded)]
|
||||||
@ -279,9 +323,9 @@ def deriv_back(dx_h: List[numpy.ndarray]) -> List[sparse.spmatrix]:
|
|||||||
dx_h_expanded = numpy.meshgrid(*dx_h, indexing='ij')
|
dx_h_expanded = numpy.meshgrid(*dx_h, indexing='ij')
|
||||||
|
|
||||||
def deriv(axis):
|
def deriv(axis):
|
||||||
return rotation(axis, shape) - sparse.eye(n)
|
return rotation(axis, shape, -1) - sparse.eye(n)
|
||||||
|
|
||||||
Ds = [sparse.diags(-1 / dx.flatten(order='F')) @ deriv(a).T
|
Ds = [sparse.diags(-1 / dx.flatten(order='F')) @ deriv(a)
|
||||||
for a, dx in enumerate(dx_h_expanded)]
|
for a, dx in enumerate(dx_h_expanded)]
|
||||||
|
|
||||||
return Ds
|
return Ds
|
||||||
|
Loading…
Reference in New Issue
Block a user