diff --git a/nom-eme.py b/nom-eme.py index b4e5dd0..2b89f8a 100644 --- a/nom-eme.py +++ b/nom-eme.py @@ -250,8 +250,27 @@ def connect_s( if k > A.shape[-1] - 1 or l > B.shape[-1] - 1: raise ValueError("port indices are out of range") - C = scipy.sparse.block_diag((A, B), dtype=complex) - return innerconnect_s(C, k, A.shape[0] + l) + #C = scipy.sparse.block_diag((A, B), dtype=complex) + #return innerconnect_s(C, k, A.shape[0] + l) + + nA = A.shape[-1] + nB = B.shape[-1] + nC = nA + nB - 2 + assert numpy.array_equal(A.shape[:-2], B.shape[:-2]) + + denom = 1 - A[..., k, k] * B[..., l, l] + Anew = A + A[..., k, :] * B[..., l, l] * A[..., :, k] / denom + Bnew = A[..., k, :] * B[..., :, l] / denom + Anew = npy.delete(Anew, (k,), 1) + Anew = npy.delete(Anew, (k,), 2) + Bnew = npy.delete(Bnew, (l,), 1) + Bnew = npy.delete(Bnew, (l,), 2) + + dtype = (A[0, 0] * B[0, 0]).dtype + C = numpy.zeros(tuple(A.shape[:-2]) + (nn, nn), dtype=dtype) + C[..., :nA - 1, :nA - 1] = Anew + C[..., nA - 1:, nA - 1:] = Bnew + return C def innerconnect_s(