From d12ce6c957e93334838203aeb620690154e03f9b Mon Sep 17 00:00:00 2001 From: jan Date: Mon, 4 Jul 2016 19:30:36 -0700 Subject: [PATCH] Include old csr solver --- opencl_fdfd/csr.py | 195 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 195 insertions(+) create mode 100644 opencl_fdfd/csr.py diff --git a/opencl_fdfd/csr.py b/opencl_fdfd/csr.py new file mode 100644 index 0000000..ed2a338 --- /dev/null +++ b/opencl_fdfd/csr.py @@ -0,0 +1,195 @@ +import numpy + +import pyopencl +import pyopencl.array +from pyopencl.elementwise import ElementwiseKernel +from pyopencl.reduction import ReductionKernel + +import time + + +def type_to_C(float_type: numpy.float32 or numpy.float64) -> str: + """ + Returns a string corresponding to the C equivalent of a numpy type. + + :param float_type: numpy type: float32, float64, complex64, complex128 + :return: string containing the corresponding C type (eg. 'double') + """ + types = { + numpy.float32: 'float', + numpy.float64: 'double', + numpy.complex64: 'cfloat_t', + numpy.complex128: 'cdouble_t', + } + if float_type not in types: + raise Exception('Unsupported type') + + return types[float_type] + + +def create_ops(context): + preamble = ''' + #define PYOPENCL_DEFINE_CDOUBLE + #include + ''' + + ctype = type_to_C(numpy.complex128) + + # ------------------------------------- + + spmv_source = ''' + int start = m_row_ptr[i]; + int stop = m_row_ptr[i+1]; + cdouble_t dot = cdouble_new(0.0, 0.0); + + int col_ind, d_ind; + for (int j=start; j