aerial_image/aerial_image/no_gpu.py
2022-01-06 19:49:14 -08:00

82 lines
2.4 KiB
Python

import numpy
from numpy import pi
from numpy.fft import fftshift, ifftshift, fft2, ifft2, fftfreq
__author__ = 'Jan Petykiewicz'
def optics_transfer(kx: numpy.ndarray,
                    ky: numpy.ndarray,
                    dz: float = 200,
                    wavelength: float = 193,
                    num_aperture: float = 0.95,
                    magnification: float = 1
                    ) -> numpy.ndarray:
    """
    Pupil transfer function: a hard circular aperture carrying a quadratic defocus phase.

    Args:
        kx, ky: Spatial-frequency coordinates, same (or broadcastable) shape.
            `aerial_image` passes centered `fftfreq` grids, i.e. cycles per pixel.
        dz: Defocus distance. Presumably in the same length unit as `wavelength`
            and as 1/k; TODO confirm units with callers.
        wavelength: Illumination wavelength.
        num_aperture: Numerical aperture of the projection optics.
        magnification: Divides the aperture cutoff frequency.

    Returns:
        Complex array equal to exp(1j * pi * dz * wavelength * |k|**2) where
        |k| < num_aperture / (wavelength * magnification), and 0 elsewhere.
    """
    cutoff = num_aperture / (wavelength * magnification)
    k_squared = kx * kx + ky * ky

    # Boolean pupil mask; multiplying by it zeroes frequencies outside the aperture.
    inside_pupil = k_squared < cutoff * cutoff
    defocus_phase = numpy.exp(1j * pi * dz * wavelength * k_squared)
    return defocus_phase * inside_pupil
def source_distribution(kx: numpy.ndarray,
                        ky: numpy.ndarray,
                        wavelength: float = 193,
                        num_aperture: float = 0.95
                        ) -> numpy.ndarray:
    """
    Illumination source intensity sampled on a spatial-frequency grid.

    With radius = num_aperture / wavelength, the source is 1 where
    kx**2 + ky**2 < radius**2 AND |kx| > radius / 2 AND |ky| > radius / 2,
    and 0 elsewhere. This gives four off-axis poles on the diagonals.
    A plain disk source would keep only the first condition.

    Args:
        kx, ky: Spatial-frequency coordinates, same (or broadcastable) shape.
        wavelength: Illumination wavelength.
        num_aperture: Numerical aperture setting the source radius.

    Returns:
        float32 array of 0.0 / 1.0 source weights.
    """
    radius = num_aperture / wavelength
    half_radius = radius / 2

    in_disk = (kx * kx + ky * ky) < radius * radius
    off_x_axis = abs(kx) > half_radius
    off_y_axis = abs(ky) > half_radius

    poles = off_x_axis & off_y_axis & in_disk
    return numpy.array(poles, dtype=numpy.float32)
def aerial_image(mask: numpy.ndarray) -> numpy.ndarray:
    """
    Compute the aerial-image intensity of a 2D mask under partially coherent illumination.

    Uses Abbe-style source-point summation:
    1. Zero-pad the mask to a power of two along each axis and Fourier transform it.
    2. For every nonzero sample of `source_distribution` on the same frequency grid,
       shift the mask spectrum by that sample's offset from zero frequency.
    3. Filter the shifted spectrum through `optics_transfer` and inverse transform it.
    4. Accumulate the intensity |field|**2, weighted by the source sample.

    Args:
        mask: 2D mask transmission, e.g. a boolean array. Anything accepted by
            `numpy.asarray` works.

    Returns:
        float32 array with the *padded* shape (each axis rounded up to a power of two).
        It is not cropped back to `mask.shape` and not normalized by the total source
        weight.

    Raises:
        ValueError: If `mask` is not 2D or has an empty axis.
    """
    mask = numpy.asarray(mask)
    if mask.ndim != 2 or 0 in mask.shape:
        raise ValueError(f'mask must be a non-empty 2D array, got shape {mask.shape}')

    # Exact integer next power of two. A float log2/ceil can round the wrong way
    # for very large sizes.
    padded_shape = tuple(1 << (n - 1).bit_length() for n in mask.shape)
    mask_k = fftshift(fft2(mask, s=padded_shape))

    # Centered frequency grids (cycles per pixel); after fftshift, DC sits at index n // 2.
    kxy = tuple(fftshift(fftfreq(n)) for n in padded_shape)
    k_grid = numpy.meshgrid(*kxy, indexing='ij')
    optics_k = optics_transfer(*k_grid)
    src_k = source_distribution(*k_grid)

    cx, cy = (n // 2 for n in src_k.shape)
    out = numpy.zeros(shape=src_k.shape, dtype=numpy.float32)
    nz = src_k.nonzero()
    for kxi, kyi, src_v in zip(*nz, src_k[nz]):
        # Shift the spectrum so that rolled[k] == mask_k[k - s], where s is this source
        # point's offset from DC. One roll over both axes replaces two nested rolls.
        rolled = numpy.roll(mask_k, (kxi - cx, kyi - cy), axis=(0, 1))
        field = ifft2(ifftshift(rolled * optics_k))
        out += src_v * (field.real * field.real + field.imag * field.imag)
    return out
if __name__ == '__main__':
    from matplotlib import pyplot

    # Demo mask: a separable 2D linear chirp. Its period sweeps from wl0 pixels
    # (at index 0) to wl1 pixels (at index n) along each axis.
    wl0, wl1 = 600, 200
    n = 2048
    x, y = numpy.meshgrid(numpy.arange(n), numpy.arange(n), indexing='ij')

    # The phase 2*pi*x*(a*x + 1/wl0) has instantaneous frequency 2*a*x + 1/wl0.
    # The rate therefore carries a factor 1/2 so the frequency reaches exactly 1/wl1
    # at x = n. Without that factor it overshoots to 2/wl1 - 1/wl0.
    a = (1 / wl1 - 1 / wl0) / (2 * n)
    chirped = numpy.sin(2 * pi * x * (a * x + 1 / wl0)) * numpy.sin(2 * pi * y * (a * y + 1 / wl0))
    mask = chirped > 0.5

    fig = pyplot.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.pcolormesh(mask)
    ax.set_title('mask')

    image = aerial_image(mask)
    fig = pyplot.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.pcolormesh(image)
    ax.set_title('aerial image')

    pyplot.show()