diff --git a/fdfd_tools/operators.py b/fdfd_tools/operators.py index 7a72d8e..2fb215c 100644 --- a/fdfd_tools/operators.py +++ b/fdfd_tools/operators.py @@ -215,12 +215,13 @@ def m2j(omega: complex, return op -def rotation(axis: int, shape: List[int]) -> sparse.spmatrix: +def rotation(axis: int, shape: List[int], shift_distance: int=1) -> sparse.spmatrix: """ Utility operator for performing a circular shift along a specified axis by 1 element. :param axis: Axis to shift along. x=0, y=1, z=2 :param shape: Shape of the grid being shifted + :param shift_distance: Number of cells to shift by. May be negative. Default 1. :return: Sparse matrix for performing the circular shift """ if len(shape) not in (2, 3): @@ -228,12 +229,11 @@ def rotation(axis: int, shape: List[int]) -> sparse.spmatrix: if axis not in range(len(shape)): raise Exception('Invalid direction: {}, shape is {}'.format(axis, shape)) - n = numpy.prod(shape) - - shifts = [1 if k == axis else 0 for k in range(3)] + shifts = [abs(shift_distance) if a == axis else 0 for a in range(3)] shifted_diags = [(numpy.arange(n) + s) % n for n, s in zip(shape, shifts)] ijk = numpy.meshgrid(*shifted_diags, indexing='ij') + n = numpy.prod(shape) i_ind = numpy.arange(n) j_ind = ijk[0] + ijk[1] * shape[0] if len(shape) == 3: @@ -241,8 +241,52 @@ def rotation(axis: int, shape: List[int]) -> sparse.spmatrix: vij = (numpy.ones(n), (i_ind, j_ind.flatten(order='F'))) - D = sparse.csr_matrix(vij, shape=(n, n)) - return D + d = sparse.csr_matrix(vij, shape=(n, n)) + + if shift_distance < 0: + d = d.T + + return d + + +def shift_with_mirror(axis: int, shape: List[int], shift_distance: int=1) -> sparse.spmatrix: + """ + Utility operator for performing an n-element shift along a specified axis, with mirror + boundary conditions applied to the cells beyond the receding edge. + + :param axis: Axis to shift along. x=0, y=1, z=2 + :param shape: Shape of the grid being shifted + :param shift_distance: Number of cells to shift by. May be negative. Default 1. + :return: Sparse matrix for performing the circular shift + """ + if len(shape) not in (2, 3): + raise Exception('Invalid shape: {}'.format(shape)) + if axis not in range(len(shape)): + raise Exception('Invalid direction: {}, shape is {}'.format(axis, shape)) + if shift_distance >= shape[axis]: + raise Exception('Shift ({}) is too large for axis {} of size {}'.format( + shift_distance, axis, shape[axis])) + + def mirrored_range(n, s): + v = numpy.arange(n) + s + v = numpy.where(v >= n, 2 * n - v - 1, v) + v = numpy.where(v < 0, - 1 - v, v) + return v + + shifts = [shift_distance if a == axis else 0 for a in range(3)] + shifted_diags = [mirrored_range(n, s) for n, s in zip(shape, shifts)] + ijk = numpy.meshgrid(*shifted_diags, indexing='ij') + + n = numpy.prod(shape) + i_ind = numpy.arange(n) + j_ind = ijk[0] + ijk[1] * shape[0] + if len(shape) == 3: + j_ind += ijk[2] * shape[0] * shape[1] + + vij = (numpy.ones(n), (i_ind, j_ind.flatten(order='F'))) + + d = sparse.csr_matrix(vij, shape=(n, n)) + return d def deriv_forward(dx_e: List[numpy.ndarray]) -> List[sparse.spmatrix]: @@ -258,7 +302,7 @@ def deriv_forward(dx_e: List[numpy.ndarray]) -> List[sparse.spmatrix]: dx_e_expanded = numpy.meshgrid(*dx_e, indexing='ij') def deriv(axis): - return rotation(axis, shape) - sparse.eye(n) + return rotation(axis, shape, 1) - sparse.eye(n) Ds = [sparse.diags(+1 / dx.flatten(order='F')) @ deriv(a) for a, dx in enumerate(dx_e_expanded)] @@ -279,9 +323,9 @@ def deriv_back(dx_h: List[numpy.ndarray]) -> List[sparse.spmatrix]: dx_h_expanded = numpy.meshgrid(*dx_h, indexing='ij') def deriv(axis): - return rotation(axis, shape) - sparse.eye(n) + return rotation(axis, shape, -1) - sparse.eye(n) - Ds = [sparse.diags(-1 / dx.flatten(order='F')) @ deriv(a).T + Ds = [sparse.diags(-1 / dx.flatten(order='F')) @ deriv(a) for a, dx in enumerate(dx_h_expanded)] return Ds