Use pyfftw if available

This commit is contained in:
jan 2018-01-15 22:43:59 -08:00
parent e8f836c908
commit c1f65f61c1

View File

@ -77,7 +77,7 @@ from typing import Tuple, Callable
import logging import logging
import numpy import numpy
from numpy import pi, real, trace from numpy import pi, real, trace
from numpy.fft import fftn, ifftn, fftfreq from numpy.fft import fftfreq
import scipy import scipy
import scipy.optimize import scipy.optimize
from scipy.linalg import norm from scipy.linalg import norm
@ -88,6 +88,29 @@ from . import field_t
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try:
import pyfftw.interfaces.numpy_fft
import pyfftw.interfaces
import multiprocessing
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(3600)
fftw_args = {
'threads': multiprocessing.cpu_count(),
'overwrite_input': True,
'planner_effort': 'FFTW_EXHAUSTIVE',
}
def fftn(*args, **kwargs):
return pyfftw.interfaces.numpy_fft.fftn(*args, **kwargs, **fftw_args)
def ifftn(*args, **kwargs):
return pyfftw.interfaces.numpy_fft.ifftn(*args, **kwargs, **fftw_args)
except ImportError:
from numpy.fft import fftn, ifftn
def generate_kmn(k0: numpy.ndarray, def generate_kmn(k0: numpy.ndarray,
G_matrix: numpy.ndarray, G_matrix: numpy.ndarray,
shape: numpy.ndarray shape: numpy.ndarray