Add _create_field() and _create_eps()
This commit is contained in:
parent
1e874cb0c0
commit
2b1d906665
@ -273,7 +273,24 @@ class Simulation(object):
|
|||||||
arguments=', '.join(S_args.keys()))
|
arguments=', '.join(S_args.keys()))
|
||||||
|
|
||||||
self.update_S = lambda e: S_update(*S_args.values(), wait_for=e)
|
self.update_S = lambda e: S_update(*S_args.values(), wait_for=e)
|
||||||
|
def _create_eps(self, epsilon: List[numpy.ndarray]):
|
||||||
|
if len(epsilon) != 3:
|
||||||
|
raise Exception('Epsilon must be a list with length of 3')
|
||||||
|
if not all((e.shape == epsilon[0].shape for e in epsilon[1:])):
|
||||||
|
raise Exception('All epsilon grids must have the same shape. Shapes are {}', [e.shape for e in epsilon])
|
||||||
|
if not epsilon[0].shape == self.shape:
|
||||||
|
raise Exception('Epsilon shape mismatch. Expected {}, got {}'.format(self.shape, epsilon[0].shape))
|
||||||
|
self.eps = pyopencl.array.to_device(self.queue, vec(epsilon).astype(self.arg_type))
|
||||||
|
|
||||||
|
def _create_field(self, initial_value: List[numpy.ndarray] = None):
|
||||||
|
if initial_value is None:
|
||||||
|
return pyopencl.array.zeros_like(self.eps)
|
||||||
|
else:
|
||||||
|
if len(initial_value) != 3:
|
||||||
|
Exception('Initial field value must be a list of length 3')
|
||||||
|
if not all((f.shape == self.shape for f in initial_value)):
|
||||||
|
Exception('Initial field list elements must have same shape as epsilon elements')
|
||||||
|
return pyopencl.array.to_device(self.queue, vec(initial_value).astype(self.arg_type))
|
||||||
|
|
||||||
def type_to_C(float_type) -> str:
|
def type_to_C(float_type) -> str:
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user