You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

90 lines
2.7 KiB
Python

import numpy
from numpy import pi
from numpy.fft import fftshift, ifftshift, fft2, ifft2, fftfreq
import pyopencl
import pyopencl.array
from gpyfft.fft import FFT
__author__ = 'Jan Petykiewicz'
def optics_transfer(kx: numpy.ndarray,
                    ky: numpy.ndarray,
                    dz: float = 200,
                    wavelength: float = 193,
                    num_aperture: float = 0.95,
                    magnification: float = 1
                    ) -> numpy.ndarray:
    """
    Evaluate the optics transfer function on a spatial-frequency grid.

    The result is ``exp(1j * pi * dz * wavelength * |k|^2)`` wherever
    ``|k|^2 < (num_aperture / (wavelength * magnification))^2`` and zero
    everywhere else, with ``|k|^2 = kx^2 + ky^2``.

    Args:
        kx: Frequency coordinates along the first axis.
        ky: Frequency coordinates along the second axis (same shape as kx).
        dz: Multiplier of the quadratic phase term.
            NOTE(review): presumably a defocus distance -- confirm units.
        wavelength: Wavelength used in both the phase and the cutoff.
        num_aperture: Numerator of the cutoff frequency.
        magnification: Divides the cutoff frequency.

    Returns:
        Complex array with the same shape as the inputs.
    """
    cutoff = num_aperture / (wavelength * magnification)
    radius_sq = kx * kx + ky * ky

    # Hard aperture: only frequencies strictly inside the cutoff survive.
    inside_aperture = radius_sq < cutoff * cutoff

    # Quadratic phase; the scalar prefactor is built first, then applied
    # to the squared radius.
    phase_scale = 1j * pi * dz * wavelength
    phase = numpy.exp(phase_scale * radius_sq)
    return phase * inside_aperture
def source_distribution(kx: numpy.ndarray,
                        ky: numpy.ndarray,
                        wavelength: float = 193,
                        num_aperture: float = 0.95
                        ) -> numpy.ndarray:
    """
    Build the illumination source weights on a spatial-frequency grid.

    With ``g = num_aperture / wavelength``, a point gets weight 1 when
    ``|kx| > g / 2`` and ``|ky| > g / 2`` and ``kx^2 + ky^2 < g^2``;
    every other point gets weight 0.

    Args:
        kx: Frequency coordinates along the first axis.
        ky: Frequency coordinates along the second axis (same shape as kx).
        wavelength: Denominator of the limiting frequency ``g``.
        num_aperture: Numerator of the limiting frequency ``g``.

    Returns:
        float32 array of zeros and ones with the same shape as the inputs.
    """
    limit = num_aperture / wavelength
    half_limit = limit / 2
    radius_sq = kx * kx + ky * ky

    beyond_x = half_limit < abs(kx)
    beyond_y = half_limit < abs(ky)
    within_radius = radius_sq < limit * limit

    # The product of the boolean masks acts as a logical AND.
    selected = beyond_x * beyond_y * within_radius
    return numpy.array(selected, dtype=numpy.float32)
def aerial_image(mask: numpy.ndarray):
    """
    Accumulate the image intensity of `mask` over all source points.

    The mask is zero-padded to power-of-two dimensions and Fourier
    transformed. For every nonzero entry of `source_distribution`, the
    centered mask spectrum is circularly shifted by that entry's offset
    from the grid center, multiplied by `optics_transfer`, and
    inverse-transformed on an OpenCL device. The squared magnitude of each
    result, weighted by the source value, is summed into the output.

    Args:
        mask: 2D array to image.

    Returns:
        float32 array with the padded (power-of-two) shape.
    """
    # Round every dimension up to the next power of two.
    fft_shape = tuple(1 << int(numpy.ceil(numpy.log2(extent))) for extent in mask.shape)
    spectrum = fftshift(fft2(mask, fft_shape))

    # Centered frequency grids matching the shifted spectrum layout.
    freq_axes = tuple(fftshift(fftfreq(extent)) for extent in fft_shape)
    freq_grid = numpy.meshgrid(*freq_axes, indexing='ij')
    transfer = optics_transfer(*freq_grid)
    source = source_distribution(*freq_grid)

    context = pyopencl.create_some_context()
    cl_queue = pyopencl.CommandQueue(context)

    intensity = numpy.zeros(shape=source.shape, dtype=numpy.float32)
    active = source.nonzero()
    weights = source[active]

    # The transfer function is uploaded to obtain a device array of the
    # right shape and dtype; its contents are overwritten by set() before
    # every transform below.
    device_buf = pyopencl.array.to_device(cl_queue, transfer)
    plan = FFT(context, cl_queue, device_buf, axes=(0, 1))

    center = numpy.array(source.shape) // 2
    for row, col, weight in zip(*active, weights):
        # Offset of this source point from the center of the grid.
        shift_0, shift_1 = (row, col) - center
        shifted = numpy.roll(spectrum, shift_0, axis=0)
        shifted = numpy.roll(shifted, shift_1, axis=1)

        # Inverse FFT on the device. The original code kept a commented-out
        # host alternative: ifft2(ifftshift(rolled * optics_k)).
        device_buf.set(ifftshift(shifted * transfer))
        plan.enqueue(forward=False)[0].wait()
        field = device_buf.get()

        real_part, imag_part = field.real, field.imag
        intensity += weight * (real_part * real_part + imag_part * imag_part)
    return intensity
if __name__ == '__main__':
    # Demo: image a thresholded 2D chirped-sine pattern and plot the
    # mask and the resulting exposure in two separate figures.
    from matplotlib import pyplot

    wl0, wl1 = 600, 200
    n = 2048

    axis = numpy.arange(n)
    x, y = numpy.meshgrid(axis, axis, indexing='ij')

    # Linear ramp of the inverse period from 1/wl0 towards 1/wl1 across n.
    a = (1/wl1 - 1/wl0) / n

    def _chirp(t):
        # 1D chirped sine evaluated on a coordinate grid.
        return numpy.sin(2 * pi * t * (a * t + 1/wl0))

    chirped = _chirp(x) * _chirp(y)
    mask = chirped > 0.5

    mask_fig = pyplot.figure()
    mask_ax = mask_fig.add_subplot(1, 1, 1)
    mask_ax.pcolormesh(mask)

    exp = aerial_image(mask)

    exp_fig = pyplot.figure()
    exp_ax = exp_fig.add_subplot(1, 1, 1)
    exp_ax.pcolormesh(exp)

    pyplot.show()