forked from jan/fdfd_tools
		
	Use pyfftw if available
This commit is contained in:
		
							parent
							
								
									e8f836c908
								
							
						
					
					
						commit
						c1f65f61c1
					
				@ -77,7 +77,7 @@ from typing import Tuple, Callable
 | 
			
		||||
import logging
 | 
			
		||||
import numpy
 | 
			
		||||
from numpy import pi, real, trace
 | 
			
		||||
from numpy.fft import fftn, ifftn, fftfreq
 | 
			
		||||
from numpy.fft import fftfreq
 | 
			
		||||
import scipy
 | 
			
		||||
import scipy.optimize
 | 
			
		||||
from scipy.linalg import norm
 | 
			
		||||
@ -88,6 +88,29 @@ from . import field_t
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    import pyfftw.interfaces.numpy_fft
 | 
			
		||||
    import pyfftw.interfaces
 | 
			
		||||
    import multiprocessing
 | 
			
		||||
 | 
			
		||||
    pyfftw.interfaces.cache.enable()
 | 
			
		||||
    pyfftw.interfaces.cache.set_keepalive_time(3600)
 | 
			
		||||
    fftw_args = {
 | 
			
		||||
        'threads': multiprocessing.cpu_count(),
 | 
			
		||||
        'overwrite_input': True,
 | 
			
		||||
        'planner_effort': 'FFTW_EXHAUSTIVE',
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    def fftn(*args, **kwargs):
 | 
			
		||||
        return pyfftw.interfaces.numpy_fft.fftn(*args, **kwargs, **fftw_args)
 | 
			
		||||
 | 
			
		||||
    def ifftn(*args, **kwargs):
 | 
			
		||||
        return pyfftw.interfaces.numpy_fft.ifftn(*args, **kwargs, **fftw_args)
 | 
			
		||||
 | 
			
		||||
except ImportError:
 | 
			
		||||
    from numpy.fft import fftn, ifftn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_kmn(k0: numpy.ndarray,
 | 
			
		||||
                 G_matrix: numpy.ndarray,
 | 
			
		||||
                 shape: numpy.ndarray
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user