forked from jan/fdfd_tools
Add solvers submodule and clean up examples.
Solvers submodule includes a generic solver in case you already have a sparse matrix solver, or in case you have no solver at all. Example file now uses alternate solvers if available, and has a nicer way of picking which solver gets used.
This commit is contained in:
parent
85880c859e
commit
ec674fe3f4
3 changed files with 153 additions and 16 deletions
|
|
@ -1,18 +1,22 @@
|
|||
import importlib
|
||||
import numpy
|
||||
from numpy.linalg import norm
|
||||
|
||||
from fdfd_tools import vec, unvec, waveguide_mode
|
||||
import fdfd_tools, fdfd_tools.functional, fdfd_tools.grid
|
||||
import fdfd_tools
|
||||
import fdfd_tools.functional
|
||||
import fdfd_tools.grid
|
||||
from fdfd_tools.solvers import generic as generic_solver
|
||||
|
||||
import gridlock
|
||||
|
||||
from matplotlib import pyplot
|
||||
|
||||
#import magma_fdfd
|
||||
from opencl_fdfd import cg_solver, csr
|
||||
|
||||
__author__ = 'Jan Petykiewicz'
|
||||
|
||||
|
||||
def test0():
|
||||
def test0(solver=generic_solver):
|
||||
dx = 50 # discretization (nm/cell)
|
||||
pml_thickness = 10 # (number of cells)
|
||||
|
||||
|
|
@ -59,21 +63,27 @@ def test0():
|
|||
J = [numpy.zeros_like(grid.grids[0], dtype=complex) for _ in range(3)]
|
||||
J[1][15, grid.shape[1]//2, grid.shape[2]//2] = 1e5
|
||||
|
||||
'''
|
||||
Solve!
|
||||
'''
|
||||
x = solver(J=vec(J), **sim_args)
|
||||
|
||||
A = fdfd_tools.functional.e_full(omega, dxes, vec(grid.grids)).tocsr()
|
||||
b = -1j * omega * vec(J)
|
||||
print('Norm of the residual is ', norm(A @ x - b))
|
||||
|
||||
x = solve_A(A, b)
|
||||
E = unvec(x, grid.shape)
|
||||
|
||||
print('Norm of the residual is {}'.format(numpy.linalg.norm(A.dot(x) - b)/numpy.linalg.norm(b)))
|
||||
|
||||
'''
|
||||
Plot results
|
||||
'''
|
||||
pyplot.figure()
|
||||
pyplot.pcolor(numpy.real(E[1][:, :, grid.shape[2]//2]), cmap='seismic')
|
||||
pyplot.axis('equal')
|
||||
pyplot.show()
|
||||
|
||||
|
||||
def test1():
|
||||
def test1(solver=generic_solver):
|
||||
dx = 40 # discretization (nm/cell)
|
||||
pml_thickness = 10 # (number of cells)
|
||||
|
||||
|
|
@ -142,17 +152,14 @@ def test1():
|
|||
'pmc': vec(pmcg.grids),
|
||||
}
|
||||
|
||||
x = solver(J=vec(J), **sim_args)
|
||||
|
||||
b = -1j * omega * vec(J)
|
||||
A = fdfd_tools.operators.e_full(**sim_args).tocsr()
|
||||
# x = magma_fdfd.solve_A(A, b)
|
||||
|
||||
# x = csr.cg_solver(J=vec(J), **sim_args)
|
||||
x = cg_solver(J=vec(J), **sim_args)
|
||||
print('Norm of the residual is ', norm(A @ x - b))
|
||||
|
||||
E = unvec(x, grid.shape)
|
||||
|
||||
print('Norm of the residual is ', numpy.linalg.norm(A @ x - b))
|
||||
|
||||
'''
|
||||
Plot results
|
||||
'''
|
||||
|
|
@ -197,6 +204,22 @@ def test1():
|
|||
pyplot.show()
|
||||
print('Average overlap with mode:', sum(q)/len(q))
|
||||
|
||||
|
||||
def module_available(name):
|
||||
return importlib.util.find_spec(name) is not None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test0()
|
||||
test1()
|
||||
|
||||
if module_available('opencl_fdfd'):
|
||||
from opencl_fdfd import cg_solver as opencl_solver
|
||||
test1(opencl_solver)
|
||||
# from opencl_fdfd.csr import fdfd_cg_solver as opencl_csr_solver
|
||||
# test1(opencl_csr_solver)
|
||||
# elif module_available('magma_fdfd'):
|
||||
# from magma_fdfd import solver as magma_solver
|
||||
# test1(magma_solver)
|
||||
else:
|
||||
test1()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue