I am trying to implement the complex and real Schur decompositions using only the Keras Ops API, so I can use them to impose eigenvalue constraints on a weight matrix. I am also using the JAX backend, which further limits what I can do with the input matrix:
- All operations are done in single-precision floating point, and no in-place modifications are allowed.
- The stopping criterion has to be static (e.g., an argument of the decomposition function, as I am currently doing) or a function of the input shape (similar to the max-iteration heuristic used by the LAPACK implementation), since the execution graph is pre-compiled from a traced array (see the sketch after this list).
- This also means that conversion to a numpy ndarray is not allowed.
- I could, in theory, use JAX-specific functions, since the Keras tensors are just JAX arrays in disguise. However, I want to keep the code backend-agnostic if possible.
- It is worth noting that JAX currently has a CPU-only implementation of the Schur algorithm; however, my code requires GPU execution, so I cannot use it.
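To make the static-iteration constraint concrete, here is a minimal sketch of the kind of loop that traces fine under `jax.jit` with the JAX backend (the function name and the iteration count are purely illustrative); a `while` loop whose condition depends on tensor values would not trace:

import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before keras is imported

import jax
import numpy as np
from keras import ops

def fixed_count_qr(a, num_iters: int = 50):
    # `num_iters` is a static Python int, so this loop is simply unrolled at trace time.
    z = ops.eye(a.shape[0])
    for _ in range(num_iters):
        q, r = ops.qr(a)
        a, z = r @ q, z @ q
    # A value-dependent condition such as `while ops.norm(...) > tol:` would raise a
    # tracer error here, because the traced condition has no concrete boolean value.
    return a, z

jitted = jax.jit(fixed_count_qr, static_argnames="num_iters")
t, z = jitted(np.random.random((4, 4)).astype("float32"), num_iters=50)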
With that being said, I've already made some progress using these lecture notes from ETH Zürich. More specifically, I got Algorithm 4.1 (basic QR algorithm) working (variable naming follows SciPy conventions):
from math import pi, inf
from typing import Literal

import numpy as np  # used by the orthogonality assert in hessenberg() and the tests below
from keras import KerasTensor as Tensor
from keras import ops

type OutputType = Literal["real", "complex"]


def quasi_triu(a: Tensor) -> Tensor:
    # Keep the upper triangle plus the subdiagonal entries (1, 0), (3, 2), ...
    # of the assumed 2x2 diagonal blocks of a real quasi-triangular matrix
    mod_idxs = ops.stack([
        ops.arange(1, len(a), 2),
        ops.arange(1, len(a), 2) - 1
    ], axis=1)
    return ops.scatter_update(ops.triu(a), mod_idxs, a[*mod_idxs.T])
def schur(a: Tensor, output: OutputType,
          num_iters: int) -> tuple[Tensor, Tensor]:
    a0, a1, z, cnt = None, a, ops.eye(a.shape[0]), 0
    while cnt < num_iters:
        a0, cnt = a1, cnt + 1
        q, r = ops.qr(a0)
        a1 = r @ q
        z = z @ q
    if output == "complex" or str(a.dtype).startswith("complex"):
        return ops.triu(a1), z
    else:
        return quasi_triu(a1), z
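The `conj_transpose` helper used throughout is nothing more than the conjugate transpose, roughly:

def conj_transpose(a: Tensor) -> Tensor:
    # Conjugate transpose; for real inputs this reduces to a plain transpose
    return ops.conj(ops.transpose(a))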
However, as expected, the convergence rate is so low that I have to choose between preemptively setting an absurd number of iterations or risking instability due to non-convergence. Thus, I'm now trying to get Algorithm 4.4 (Hessenberg QR algorithm with Rayleigh quotient shift) to work. So far, I have implemented the reduction to Hessenberg form through Householder transformations, following the theory on Wikipedia, with some modifications to stay compatible with the SciPy convention ($A = QHQ^*$ instead of $H = QAQ^*$):
def hessenberg(a: Tensor) -> tuple[Tensor, Tensor]:
    a = ops.convert_to_tensor(a)
    n = a.shape[0]
    u = ops.eye(n)
    for k in range(n - 2):
        # Sign/phase of the leading entry of the column to be eliminated (-1 if it is zero)
        aux_coeff = ops.where(
            a[k + 1, k] == 0,
            -1,
            a[k + 1, k] / ops.abs(a[k + 1, k])
        )
        # Householder vector for the part of column k below the subdiagonal
        w = ops.norm(
            a[k + 1:, k], 2
        ) * ops.eye(n - (k + 1), 1) + aux_coeff * a[k + 1:, k:k + 1]
        # Householder reflector acting on the trailing (n - k - 1) rows/columns
        v_householder = ops.eye(
            n - (k + 1)
        ) - 2 * w @ ops.conj(ops.transpose(w)) / ops.norm(w, 2)**2
        u_k = ops.slice_update(
            ops.eye(n, dtype=v_householder.dtype),
            (k + 1, k + 1),
            v_householder
        )
        a = conj_transpose(u_k) @ a @ u_k  # Necessary due to single precision
        u = u @ u_k
        assert np.allclose(np.eye(n), conj_transpose(u) @ u, atol=1e-6)
    return a, u
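For reference, the Hessenberg reconstruction check I mention further down is along these lines (the matrix and tolerances are illustrative; I only look at the worst-case entry):

test_hess = np.random.random((5, 5)).astype("float32")
h, u = hessenberg(test_hess)
# U should be (numerically) unitary and A should be recovered as U H U^*
print(np.max(np.abs(np.eye(5) - np.asarray(conj_transpose(u) @ u))))
print(np.max(np.abs(test_hess - np.asarray(u @ h @ conj_transpose(u)))))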
However, I haven't been able to get the actual algorithm working. So far, I have the following function:
def schur_hessenberg(a: Tensor, output: OutputType,
                     num_iters: int) -> tuple[Tensor, Tensor]:
    n = a.shape[0]
    h, q = hessenberg(a)
    z = ops.eye(n)
    for m in range(n - 1, 0, -1):
        for _ in range(num_iters):
            sigma = h[m, m]
            q, r = ops.qr(h - sigma * ops.eye(n))
            h = r @ q + sigma * ops.eye(n)
            z = z @ q
    if output == "complex" or str(a.dtype).startswith("complex"):
        return ops.triu(h), q @ z
    else:
        return quasi_triu(h), q @ z
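For reference, the relation I am relying on here is that each shifted step is a unitary similarity, $H_k = R_k Q_k + \sigma_k I = Q_k^* H_{k-1} Q_k$, so after all iterations the accumulated factors should satisfy $A = (U Q_1 \cdots Q_K)\, T\, (U Q_1 \cdots Q_K)^*$, with $U$ the unitary factor from `hessenberg` and $T$ the final (quasi-)upper-triangular matrix.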
However, when I test both methods with a random real matrix (the complex case gives similar results) and measure the elementwise absolute reconstruction error as
test_matrix = np.random.random((5, 5))
t0, z0 = schur_hessenberg(test_matrix, "real", 100000)
t1, z1 = schur(test_matrix, "real", 100000)
print(np.abs(test_matrix - z0 @ t0 @ conj_transpose(z0)))
print(np.abs(test_matrix - z1 @ t1 @ conj_transpose(z1)))
I get the following results:
[[0.02052921 0.62834513 1.0401626 0.08546968 0.50802547]
[1.2145011 0.47686994 0.14543009 0.77371585 0.44549757]
[1.7762737 0.04462999 0.84596455 0.12113664 0.42194515]
[0.6079759 0.8015515 0.0135076 0.07646421 0.00241762]
[0.12112255 0.72910243 0.5631428 0.03763434 1.2257423 ]]
[[6.2525272e-05 4.6223402e-05 5.9843063e-05 6.6056848e-05 1.8805265e-05]
[2.8496981e-04 1.1801720e-05 1.1438131e-04 3.2186508e-06 2.9633939e-04]
[8.8214874e-06 2.5629997e-05 1.8000603e-05 3.6358833e-05 2.2232533e-05]
[1.5932322e-04 3.5166740e-05 2.1210406e-05 9.1180205e-05 1.7516315e-04]
[2.7435273e-04 5.5968761e-05 3.5285950e-05 8.4638596e-06 1.6790628e-04]]
Clearly, the Hessenberg-based decomposition is not properly implemented. I already checked the reconstruction error of the Hessenberg decomposition and, numerical precision aside, it seems to be working. Furthermore, the $H$ matrix generated by the Hessenberg QR iterations seems to be (at least) quasi-upper triangular. Thus, I have either (a) incorrectly implemented the Hessenberg QR iterations, or (b) failed to properly combine the Hessenberg and Schur factors.
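One rough diagnostic I considered to tell (a) and (b) apart, reusing `test_matrix` and `t0` from above: if $T$ already carries the spectrum of $A$, the iterations are presumably doing their job and the problem lies in how the unitary factors are combined.

# If the eigenvalues of T match those of A, the similarity transformations inside
# the QR iterations are consistent, and the error is likely in assembling Z.
eig_a = np.sort_complex(np.linalg.eigvals(test_matrix))
eig_t = np.sort_complex(np.linalg.eigvals(np.asarray(t0)))
print(np.max(np.abs(eig_a - eig_t)))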
I've already gone over the code and reference material several times, but I cannot spot where I'm screwing up, so I'd really appreciate any pointers.