forked from jan/opencl_fdtd
		
	Add _create_context(), _create_operation(), and _create_pmls(), and generalize initial field value args
This commit is contained in:
		
							parent
							
								
									2b1d906665
								
							
						
					
					
						commit
						f00c8b4a3e
					
				@ -76,8 +76,7 @@ class Simulation(object):
 | 
				
			|||||||
                 epsilon: List[numpy.ndarray],
 | 
					                 epsilon: List[numpy.ndarray],
 | 
				
			||||||
                 pmls: List[Dict[str, int or float]],
 | 
					                 pmls: List[Dict[str, int or float]],
 | 
				
			||||||
                 dt: float = .99/numpy.sqrt(3),
 | 
					                 dt: float = .99/numpy.sqrt(3),
 | 
				
			||||||
                 initial_E: List[numpy.ndarray] = None,
 | 
					                 initial_fields: Dict[str, List[numpy.ndarray]] = None,
 | 
				
			||||||
                 initial_H: List[numpy.ndarray] = None,
 | 
					 | 
				
			||||||
                 context: pyopencl.Context = None,
 | 
					                 context: pyopencl.Context = None,
 | 
				
			||||||
                 queue: pyopencl.CommandQueue = None,
 | 
					                 queue: pyopencl.CommandQueue = None,
 | 
				
			||||||
                 float_type: numpy.float32 or numpy.float64 = numpy.float32,
 | 
					                 float_type: numpy.float32 or numpy.float64 = numpy.float32,
 | 
				
			||||||
@ -113,21 +112,14 @@ class Simulation(object):
 | 
				
			|||||||
                    * GPU memory requirements are approximately doubled, since S and the intermediate
 | 
					                    * GPU memory requirements are approximately doubled, since S and the intermediate
 | 
				
			||||||
                        products must be stored.
 | 
					                        products must be stored.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					        if initial_fields is None:
 | 
				
			||||||
 | 
					            initial_fields = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if len(epsilon) != 3:
 | 
					        self.shape = epsilon[0].shape
 | 
				
			||||||
            Exception('Epsilon must be a list with length of 3')
 | 
					        self.arg_type = float_type
 | 
				
			||||||
        if not all((e.shape == epsilon[0].shape for e in epsilon[1:])):
 | 
					        self.sources = {}
 | 
				
			||||||
            Exception('All epsilon grids must have the same shape. Shapes are {}', [e.shape for e in epsilon])
 | 
					        self._create_context(context, queue)
 | 
				
			||||||
 | 
					        self._create_eps(epsilon)
 | 
				
			||||||
        if context is None:
 | 
					 | 
				
			||||||
            self.context = pyopencl.create_some_context()
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            self.context = context
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if queue is None:
 | 
					 | 
				
			||||||
            self.queue = pyopencl.CommandQueue(self.context)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            self.queue = queue
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if dt > .99/numpy.sqrt(3):
 | 
					        if dt > .99/numpy.sqrt(3):
 | 
				
			||||||
            warnings.warn('Warning: unstable dt: {}'.format(dt))
 | 
					            warnings.warn('Warning: unstable dt: {}'.format(dt))
 | 
				
			||||||
@ -136,27 +128,8 @@ class Simulation(object):
 | 
				
			|||||||
        else:
 | 
					        else:
 | 
				
			||||||
            self.dt = dt
 | 
					            self.dt = dt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.arg_type = float_type
 | 
					        self.E = self._create_field(initial_fields.get('E', None))
 | 
				
			||||||
        self.sources = {}
 | 
					        self.H = self._create_field(initial_fields.get('H', None))
 | 
				
			||||||
        self.eps = pyopencl.array.to_device(self.queue, vec(epsilon).astype(float_type))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if initial_E is None:
 | 
					 | 
				
			||||||
            self.E = pyopencl.array.zeros_like(self.eps)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            if len(initial_E) != 3:
 | 
					 | 
				
			||||||
                Exception('Initial_E must be a list of length 3')
 | 
					 | 
				
			||||||
            if not all((E.shape == epsilon[0].shape for E in initial_E)):
 | 
					 | 
				
			||||||
                Exception('Initial_E list elements must have same shape as epsilon elements')
 | 
					 | 
				
			||||||
            self.E = pyopencl.array.to_device(self.queue, vec(E).astype(float_type))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if initial_H is None:
 | 
					 | 
				
			||||||
            self.H = pyopencl.array.zeros_like(self.eps)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            if len(initial_H) != 3:
 | 
					 | 
				
			||||||
                Exception('Initial_H must be a list of length 3')
 | 
					 | 
				
			||||||
            if not all((H.shape == epsilon[0].shape for H in initial_H)):
 | 
					 | 
				
			||||||
                Exception('Initial_H list elements must have same shape as epsilon elements')
 | 
					 | 
				
			||||||
            self.H = pyopencl.array.to_device(self.queue, vec(H).astype(float_type))
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for pml in pmls:
 | 
					        for pml in pmls:
 | 
				
			||||||
            pml.setdefault('thickness', 8)
 | 
					            pml.setdefault('thickness', 8)
 | 
				
			||||||
@ -181,7 +154,7 @@ class Simulation(object):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        common_source = jinja_env.get_template('common.cl').render(
 | 
					        common_source = jinja_env.get_template('common.cl').render(
 | 
				
			||||||
                ftype=ctype,
 | 
					                ftype=ctype,
 | 
				
			||||||
                shape=epsilon[0].shape,
 | 
					                shape=self.shape,
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
        jinja_args = {
 | 
					        jinja_args = {
 | 
				
			||||||
                'common_header': common_source,
 | 
					                'common_header': common_source,
 | 
				
			||||||
@ -194,21 +167,37 @@ class Simulation(object):
 | 
				
			|||||||
        self.sources['E'] = E_source
 | 
					        self.sources['E'] = E_source
 | 
				
			||||||
        self.sources['H'] = H_source
 | 
					        self.sources['H'] = H_source
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        S_fields = OrderedDict()
 | 
				
			||||||
        if do_poynting:
 | 
					        if do_poynting:
 | 
				
			||||||
            S_source = jinja_env.get_template('update_s.cl').render(**jinja_args)
 | 
					            S_source = jinja_env.get_template('update_s.cl').render(**jinja_args)
 | 
				
			||||||
            self.sources['S'] = S_source
 | 
					            self.sources['S'] = S_source
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            self.oS = pyopencl.array.zeros(self.queue, self.E.shape + (2,), dtype=float_type)
 | 
					            self.oS = pyopencl.array.zeros(self.queue, self.E.shape + (2,), dtype=self.arg_type)
 | 
				
			||||||
            self.S = pyopencl.array.zeros_like(self.E)
 | 
					            self.S = pyopencl.array.zeros_like(self.E)
 | 
				
			||||||
            S_fields = OrderedDict()
 | 
					 | 
				
			||||||
            S_fields[ptr('oS')] = self.oS
 | 
					            S_fields[ptr('oS')] = self.oS
 | 
				
			||||||
            S_fields[ptr('S')] = self.S
 | 
					            S_fields[ptr('S')] = self.S
 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            S_fields = OrderedDict()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        '''
 | 
					        '''
 | 
				
			||||||
        PML
 | 
					        PML
 | 
				
			||||||
        '''
 | 
					        '''
 | 
				
			||||||
 | 
					        pml_e_fields, pml_h_fields = self._create_pmls(pmls)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        '''
 | 
				
			||||||
 | 
					        Create operations
 | 
				
			||||||
 | 
					        '''
 | 
				
			||||||
 | 
					        self.update_E = self._create_operation(E_source, (base_fields, eps_field, pml_e_fields))
 | 
				
			||||||
 | 
					        self.update_H = self._create_operation(H_source, (base_fields, pml_h_fields, S_fields))
 | 
				
			||||||
 | 
					        if do_poynting:
 | 
				
			||||||
 | 
					            self.update_S = self._create_operation(S_source, (base_fields, S_fields))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _create_pmls(self, pmls):
 | 
				
			||||||
 | 
					        ctype = type_to_C(self.arg_type)
 | 
				
			||||||
 | 
					        def ptr(arg: str) -> str:
 | 
				
			||||||
 | 
					            return ctype + ' *' + arg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        pml_e_fields = OrderedDict()
 | 
					        pml_e_fields = OrderedDict()
 | 
				
			||||||
        pml_h_fields = OrderedDict()
 | 
					        pml_h_fields = OrderedDict()
 | 
				
			||||||
        for pml in pmls:
 | 
					        for pml in pmls:
 | 
				
			||||||
@ -225,7 +214,7 @@ class Simulation(object):
 | 
				
			|||||||
                p1 = sigma / (sigma + alpha) * (p0 - 1)
 | 
					                p1 = sigma / (sigma + alpha) * (p0 - 1)
 | 
				
			||||||
                return p0, p1
 | 
					                return p0, p1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            xe, xh = (numpy.arange(1, pml['thickness'] + 1, dtype=float_type)[::-1] for _ in range(2))
 | 
					            xe, xh = (numpy.arange(1, pml['thickness'] + 1, dtype=self.arg_type)[::-1] for _ in range(2))
 | 
				
			||||||
            if pml['polarity'] == 'p':
 | 
					            if pml['polarity'] == 'p':
 | 
				
			||||||
                xe -= 0.5
 | 
					                xe -= 0.5
 | 
				
			||||||
            elif pml['polarity'] == 'n':
 | 
					            elif pml['polarity'] == 'n':
 | 
				
			||||||
@ -240,39 +229,34 @@ class Simulation(object):
 | 
				
			|||||||
            psi_base = 'Psi_' + pml['axis'] + pml['polarity'] + '_'
 | 
					            psi_base = 'Psi_' + pml['axis'] + pml['polarity'] + '_'
 | 
				
			||||||
            psi_names = [[psi_base + eh + c for c in uv] for eh in 'EH']
 | 
					            psi_names = [[psi_base + eh + c for c in uv] for eh in 'EH']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            psi_shape = list(epsilon[0].shape)
 | 
					            psi_shape = list(self.shape)
 | 
				
			||||||
            psi_shape[a] = pml['thickness']
 | 
					            psi_shape[a] = pml['thickness']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            for ne, nh in zip(*psi_names):
 | 
					            for ne, nh in zip(*psi_names):
 | 
				
			||||||
                pml_e_fields[ptr(ne)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type)
 | 
					                pml_e_fields[ptr(ne)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type)
 | 
				
			||||||
                pml_h_fields[ptr(nh)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type)
 | 
					                pml_h_fields[ptr(nh)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type)
 | 
				
			||||||
 | 
					        return pml_e_fields, pml_h_fields
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.pml_e_fields = pml_e_fields
 | 
					    def _create_operation(self, source, args_fields):
 | 
				
			||||||
        self.pml_h_fields = pml_h_fields
 | 
					        args = OrderedDict()
 | 
				
			||||||
 | 
					        [args.update(d) for d in args_fields]
 | 
				
			||||||
 | 
					        update = ElementwiseKernel(self.context, operation=source,
 | 
				
			||||||
 | 
					                                   arguments=', '.join(args.keys()))
 | 
				
			||||||
 | 
					        return lambda e: update(*args.values(), wait_for=e)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        '''
 | 
					    def _create_context(self, context: pyopencl.Context = None,
 | 
				
			||||||
        Create operations
 | 
					                        queue: pyopencl.CommandQueue = None):
 | 
				
			||||||
        '''
 | 
					        if context is None:
 | 
				
			||||||
        E_args = OrderedDict()
 | 
					            self.context = pyopencl.create_some_context()
 | 
				
			||||||
        [E_args.update(d) for d in (base_fields, eps_field, pml_e_fields)]
 | 
					        else:
 | 
				
			||||||
        E_update = ElementwiseKernel(self.context, operation=E_source,
 | 
					            self.context = context
 | 
				
			||||||
                                     arguments=', '.join(E_args.keys()))
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        H_args = OrderedDict()
 | 
					        if queue is None:
 | 
				
			||||||
        [H_args.update(d) for d in (base_fields, pml_h_fields, S_fields)]
 | 
					            self.queue = pyopencl.CommandQueue(self.context)
 | 
				
			||||||
        H_update = ElementwiseKernel(self.context, operation=H_source,
 | 
					        else:
 | 
				
			||||||
                                     arguments=', '.join(H_args.keys()))
 | 
					            self.queue = queue
 | 
				
			||||||
        self.update_E = lambda e: E_update(*E_args.values(), wait_for=e)
 | 
					 | 
				
			||||||
        self.update_H = lambda e: H_update(*H_args.values(), wait_for=e)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if do_poynting:
 | 
					 | 
				
			||||||
            S_args = OrderedDict()
 | 
					 | 
				
			||||||
            [S_args.update(d) for d in (base_fields, S_fields)]
 | 
					 | 
				
			||||||
            S_update = ElementwiseKernel(self.context, operation=S_source,
 | 
					 | 
				
			||||||
                                         arguments=', '.join(S_args.keys()))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            self.update_S = lambda e: S_update(*S_args.values(), wait_for=e)
 | 
					 | 
				
			||||||
    def _create_eps(self, epsilon: List[numpy.ndarray]):
 | 
					    def _create_eps(self, epsilon: List[numpy.ndarray]):
 | 
				
			||||||
        if len(epsilon) != 3:
 | 
					        if len(epsilon) != 3:
 | 
				
			||||||
            raise Exception('Epsilon must be a list with length of 3')
 | 
					            raise Exception('Epsilon must be a list with length of 3')
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user