diff --git a/fdfd_tools/bloch.py b/fdfd_tools/bloch.py index 1816bdc..fd5f758 100644 --- a/fdfd_tools/bloch.py +++ b/fdfd_tools/bloch.py @@ -77,7 +77,7 @@ from typing import Tuple, Callable import logging import numpy from numpy import pi, real, trace -from numpy.fft import fftn, ifftn, fftfreq +from numpy.fft import fftfreq import scipy import scipy.optimize from scipy.linalg import norm @@ -88,6 +88,29 @@ from . import field_t logger = logging.getLogger(__name__) +try: + import pyfftw.interfaces.numpy_fft + import pyfftw.interfaces + import multiprocessing + + pyfftw.interfaces.cache.enable() + pyfftw.interfaces.cache.set_keepalive_time(3600) + fftw_args = { + 'threads': multiprocessing.cpu_count(), + 'overwrite_input': True, + 'planner_effort': 'FFTW_EXHAUSTIVE', + } + + def fftn(*args, **kwargs): + return pyfftw.interfaces.numpy_fft.fftn(*args, **kwargs, **fftw_args) + + def ifftn(*args, **kwargs): + return pyfftw.interfaces.numpy_fft.ifftn(*args, **kwargs, **fftw_args) + +except ImportError: + from numpy.fft import fftn, ifftn + + def generate_kmn(k0: numpy.ndarray, G_matrix: numpy.ndarray, shape: numpy.ndarray