diff --git a/opencl_fdfd/ops.py b/opencl_fdfd/ops.py index 16d0d6b..80bf836 100644 --- a/opencl_fdfd/ops.py +++ b/opencl_fdfd/ops.py @@ -61,17 +61,17 @@ ctype = type_to_C(numpy.complex128) ctype_bare = 'cdouble' # Preamble for all OpenCL code -preamble = ''' +preamble = f''' #define PYOPENCL_DEFINE_CDOUBLE #include //Defines to clean up operation and type names -#define ctype {ctype}_t -#define zero {ctype}_new(0.0, 0.0) -#define add {ctype}_add -#define sub {ctype}_sub -#define mul {ctype}_mul -'''.format(ctype=ctype_bare) +#define ctype {ctype_bare}_t +#define zero {ctype_bare}_new(0.0, 0.0) +#define add {ctype_bare}_add +#define sub {ctype_bare}_sub +#define mul {ctype_bare}_mul +''' def ptrs(*args: str) -> list[str]: