diff --git a/meanas/fdfd/bloch.py b/meanas/fdfd/bloch.py
index 2e1da30..71b2a8b 100644
--- a/meanas/fdfd/bloch.py
+++ b/meanas/fdfd/bloch.py
@@ -799,3 +799,52 @@ def _rtrace_AtB(
def _symmetrize(A: NDArray[numpy.complex128]) -> NDArray[numpy.complex128]:
return (A + A.conj().T) * 0.5
+
+
+def inner_product(eL, hL, eR, hR) -> complex:
+ # assumes x-axis propagation
+
+ assert numpy.array_equal(eR.shape, hR.shape)
+ assert numpy.array_equal(eL.shape, hL.shape)
+ assert numpy.array_equal(eR.shape, eL.shape)
+
+ # Cross product, times 2 since it's
, then divide by 4. # TODO might want to abs() this?
+ norm2R = (eR[1] * hR[2] - eR[2] * hR[1]).sum() / 2
+ norm2L = (eL[1] * hL[2] - eL[2] * hL[1]).sum() / 2
+
+ # eRxhR_x = numpy.cross(eR.reshape(3, -1), hR.reshape(3, -1), axis=0).reshape(eR.shape)[0] / normR
+ # logger.info(f'power {eRxhR_x.sum() / 2})
+
+ eR /= numpy.sqrt(norm2R)
+ hR /= numpy.sqrt(norm2R)
+ eL /= numpy.sqrt(norm2L)
+ hL /= numpy.sqrt(norm2L)
+
+ # (eR x hL)[0] and (eL x hR)[0]
+ eRxhL_x = eR[1] * hL[2] - eR[2] - hL[1]
+ eLxhR_x = eL[1] * hR[2] - eL[2] - hR[1]
+
+ #return 1j * (eRxhL_x - eLxhR_x).sum() / numpy.sqrt(norm2R * norm2L)
+ #return (eRxhL_x.sum() - eLxhR_x.sum()) / numpy.sqrt(norm2R * norm2L)
+ return eRxhL_x.sum() - eLxhR_x.sum()
+
+
+def trq(eI, hI, eO, hO) -> tuple[complex, complex]:
+ pp = inner_product(eO, hO, eI, hI)
+ pn = inner_product(eO, hO, eI, -hI)
+ np = inner_product(eO, -hO, eI, hI)
+ nn = inner_product(eO, -hO, eI, -hI)
+
+ assert pp == -nn
+ assert pn == -np
+
+ logger.info(f'''
+ {pp=:4g} {pn=:4g}
+ {nn=:4g} {np=:4g}
+ {nn * pp / pn=:4g} {-np=:4g}
+ ''')
+
+ r = -pp / pn # -/ = -(-pp) / (-pn)
+ t = (np - nn * pp / pn) / 4
+
+ return t, r
diff --git a/meanas/fdfd/eme.py b/meanas/fdfd/eme.py
new file mode 100644
index 0000000..35e1e90
--- /dev/null
+++ b/meanas/fdfd/eme.py
@@ -0,0 +1,68 @@
+import numpy
+
+from ..fdmath import vec, unvec, dx_lists_t, vfdfield_t, vcfdfield_t
+from .waveguide_2d import inner_product
+
+
+def get_tr(ehL, wavenumbers_L, ehR, wavenumbers_R, dxes: dx_lists_t):
+ nL = len(wavenumbers_L)
+ nR = len(wavenumbers_R)
+ A12 = numpy.zeros((nL, nR), dtype=complex)
+ A21 = numpy.zeros((nL, nR), dtype=complex)
+ B11 = numpy.zeros((nL,), dtype=complex)
+ for ll in range(nL):
+ eL, hL = ehL[ll]
+ B11[ll] = inner_product(eL, hL, dxes=dxes, conj_h=False)
+ for rr in range(nR):
+ eR, hR = ehR[rr]
+ A12[ll, rr] = inner_product(eL, hR, dxes=dxes, conj_h=False) # TODO optimize loop?
+ A21[ll, rr] = inner_product(eR, hL, dxes=dxes, conj_h=False)
+
+ # tt0 = 2 * numpy.linalg.pinv(A21 + numpy.conj(A12))
+ tt0, _resid, _rank, _sing = numpy.linalg.lstsq(A21 + A12, numpy.diag(2 * B11), rcond=None)
+
+ U, st, V = numpy.linalg.svd(tt0)
+ gain = st > 1
+ st[gain] = 1 / st[gain]
+ tt = U @ numpy.diag(st) @ V
+
+ # rr = 0.5 * (A21 - numpy.conj(A12)) @ tt
+ rr = numpy.diag(0.5 / B11) @ (A21 - A12) @ tt
+
+ return tt, rr
+
+
+def get_abcd(eL_xys, wavenumbers_L, eR_xys, wavenumbers_R, **kwargs):
+ t12, r12 = get_tr(eL_xys, wavenumbers_L, eR_xys, wavenumbers_R, **kwargs)
+ t21, r21 = get_tr(eR_xys, wavenumbers_R, eL_xys, wavenumbers_L, **kwargs)
+ t21i = numpy.linalg.pinv(t21)
+ A = t12 - r21 @ t21i @ r12
+ B = r21 @ t21i
+ C = -t21i @ r12
+ D = t21i
+ return sparse.block_array(((A, B), (C, D)))
+
+
+def get_s(
+ eL_xys,
+ wavenumbers_L,
+ eR_xys,
+ wavenumbers_R,
+ force_nogain: bool = False,
+ force_reciprocal: bool = False,
+ **kwargs):
+ t12, r12 = get_tr(eL_xys, wavenumbers_L, eR_xys, wavenumbers_R, **kwargs)
+ t21, r21 = get_tr(eR_xys, wavenumbers_R, eL_xys, wavenumbers_L, **kwargs)
+
+ ss = numpy.block([[r12, t12],
+ [t21, r21]])
+
+ if force_nogain:
+ # force S @ S.H diagonal
+ U, sing, V = numpy.linalg.svd(ss)
+ ss = numpy.diag(sing) @ U @ V
+
+ if force_reciprocal:
+ ss = 0.5 * (ss + ss.T)
+
+ return ss