
I am trying to implement the Complex and Real Schur Decompositions using only the Keras Ops API, so I can utilize it to impose some eigenvalue constraints on a weight matrix. I am also using the JAX backend, which further limits what I can actually do with the input matrix:

  1. All operations are done using single-precision floating-point operations, and no in-place modifications are allowed.
  2. The stopping criterion has to be static (e.g., an argument of the decomposition function, as I am currently doing) or a function of the input shape (similar to the max-iteration heuristic used by the LAPACK implementation), because the execution graph is pre-compiled from a traced array (a sketch of a trace-compatible loop follows this list).
    • This also means that conversion to a numpy ndarray is not allowed.
  3. I could, in theory, use JAX-specific functions, since the Keras tensors are just JAX arrays in disguise. However, I want to keep the code backend-agnostic if possible.
    • It is worth noting that JAX currently has a CPU-only implementation of the Schur algorithm; however, my code requires GPU execution, so I cannot use it.
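
As a minimal sketch of constraint 2 (illustrative only, not part of my implementation): assuming `ops.fori_loop` is available in your Keras version, a statically bounded loop keeps the compiled graph small instead of unrolling a Python loop at trace time; on the JAX backend it should lower to `jax.lax.fori_loop`:

from keras import ops

def qr_iterations(a, num_iters: int):
    # `num_iters` is a Python int, so the trip count is fixed at
    # trace time and the loop body is compiled exactly once.
    def body(_, state):
        h, z = state
        q, r = ops.qr(h)
        return r @ q, z @ q

    # Carry both the iterate and the accumulated unitary factor.
    return ops.fori_loop(0, num_iters, body, (a, ops.eye(a.shape[0])))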

With that said, I've already made some progress using these lecture notes from ETH Zürich. More specifically, I got algorithm 4.1 (basic QR algorithm) working (variable naming follows SciPy's conventions):

import numpy as np

from typing import Literal

from keras import KerasTensor as Tensor
from keras import ops

type OutputType = Literal["real", "complex"]

def conj_transpose(a: Tensor) -> Tensor:
    # Hermitian (conjugate) transpose, used throughout below.
    return ops.conj(ops.transpose(a))

def quasi_triu(a: Tensor) -> Tensor:
    # Zero out the strict lower triangle except for the subdiagonal
    # entries at rows 1, 3, 5, ..., i.e., assume the 2x2 blocks of the
    # real Schur form start at even indices.
    mod_idxs = ops.stack([
        ops.arange(1, a.shape[0], 2),
        ops.arange(1, a.shape[0], 2) - 1
    ], axis=1)
    return ops.scatter_update(
        ops.triu(a), mod_idxs, a[mod_idxs[:, 0], mod_idxs[:, 1]]
    )

def schur(a: Tensor, output: OutputType,
          num_iters: int) -> tuple[Tensor, Tensor]:
    # Basic QR algorithm: A_{k+1} = R_k Q_k with A_k = Q_k R_k,
    # accumulating Z = Q_0 Q_1 ... so that A = Z T Z^*.
    a1, z = a, ops.eye(a.shape[0])
    for _ in range(num_iters):  # static trip count, trace-friendly
        q, r = ops.qr(a1)
        a1 = r @ q
        z = z @ q

    if output == "complex" or str(a.dtype).startswith("complex"):
        return ops.triu(a1), z
    else:
        return quasi_triu(a1), z
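
As a quick smoke test (illustrative only; the seed, size, and iteration count are arbitrary): a symmetric matrix has real eigenvalues, so its real Schur form is genuinely triangular and the unshifted iteration stands a fair chance of converging:

rng = np.random.default_rng(0)
m = rng.random((4, 4)).astype("float32")
m = m + m.T  # symmetric test case

t, z = schur(ops.convert_to_tensor(m), "real", num_iters=500)
print(np.max(np.abs(m - np.asarray(z @ t @ conj_transpose(z)))))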

However, as expected, the convergence rate is so low that I have to choose between preemptively setting an absurd number of iterations and risking instability due to non-convergence. Thus, I'm now trying to get algorithm 4.4 (Hessenberg QR algorithm with Rayleigh quotient shift) to work. So far, I have implemented the reduction to Hessenberg form through Householder transformations, following the theory on Wikipedia, with one modification to preserve compatibility with the SciPy version ($A = QHQ^*$ instead of $H = QAQ^*$):

def hessenberg(a: Tensor) -> tuple[Tensor, Tensor]:
    a = ops.convert_to_tensor(a)
    n = a.shape[0]
    u = ops.eye(n)

    for k in range(n - 2):
        # Sign/phase factor of a[k + 1, k], chosen to avoid cancellation;
        # the ops.where branch avoids a division by zero.
        aux_coeff = ops.where(
            a[k + 1, k] == 0,
            -1,
            a[k + 1, k] / ops.abs(a[k + 1, k])
        )

        # Householder vector w = ||x|| e_1 + sign(x_1) x, annihilating
        # the entries below the first subdiagonal of column k.
        w = ops.norm(
            a[k + 1:, k], 2
        ) * ops.eye(n - (k + 1), 1) + aux_coeff * a[k + 1:, k:k + 1]

        # Reflector I - 2 w w^* / ||w||^2, embedded into the identity
        # so it acts only on the trailing block.
        v_householder = ops.eye(
            n - (k + 1)
        ) - 2 * w @ ops.conj(ops.transpose(w)) / ops.norm(w, 2)**2
        u_k = ops.slice_update(
            ops.eye(n, dtype=v_householder.dtype),
            (k + 1, k + 1),
            v_householder
        )

        a = conj_transpose(u_k) @ a @ u_k  # Necessary due to single precision
        u = u @ u_k

    # Debugging check only: it forces a host transfer and will not work
    # on traced tensors, so it has to be dropped before compiling.
    assert np.allclose(np.eye(n), conj_transpose(u) @ u, atol=1e-6)

    return a, u
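
As a sanity check (not part of the algorithm itself, and assuming SciPy is available on the host), `scipy.linalg.hessenberg` uses the same $A = QHQ^*$ convention, so both reconstructions can be compared against the same input:

from scipy.linalg import hessenberg as scipy_hessenberg

a_np = np.random.random((5, 5)).astype("float32")
h0, q0 = hessenberg(a_np)
h1, q1 = scipy_hessenberg(a_np, calc_q=True)

print(np.max(np.abs(np.asarray(q0 @ h0 @ conj_transpose(q0)) - a_np)))
print(np.max(np.abs(q1 @ h1 @ q1.T - a_np)))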

However, I haven't been able to get the actual algorithm working. So far, I have the following function:

def schur_hessenberg(a: Tensor, output: OutputType,
                     num_iters: int) -> tuple[Tensor, Tensor]:
    n = a.shape[0]

    h, q = hessenberg(a)

    z = ops.eye(n)
    for m in range(n - 1, 0, -1):
        for _ in range(num_iters):
            sigma = h[m, m]
            q, r = ops.qr(h - sigma * ops.eye(n))
            h = r @ q + sigma * ops.eye(n)
            z = z @ q

    if output == "complex" or str(a.dtype).startswith("complex"):
        return ops.triu(h), q @ z
    else:
        return quasi_triu(h), q @ z

However, when I test both methods using a random real matrix (similar results for the complex case) and measure the matrix reconstruction absolute error as

test_matrix = np.random.random((5, 5))
t0, z0 = schur_hessenberg(test_matrix, "real", 100000)
t1, z1 = schur(test_matrix, "real", 100000)

print(np.abs(test_matrix - z0 @ t0 @ conj_transpose(z0)))
print(np.abs(test_matrix - z1 @ t1 @ conj_transpose(z1)))

I get the following results:

[[0.02052921 0.62834513 1.0401626  0.08546968 0.50802547]
 [1.2145011  0.47686994 0.14543009 0.77371585 0.44549757]
 [1.7762737  0.04462999 0.84596455 0.12113664 0.42194515]
 [0.6079759  0.8015515  0.0135076  0.07646421 0.00241762]
 [0.12112255 0.72910243 0.5631428  0.03763434 1.2257423 ]]

[[6.2525272e-05 4.6223402e-05 5.9843063e-05 6.6056848e-05 1.8805265e-05]
 [2.8496981e-04 1.1801720e-05 1.1438131e-04 3.2186508e-06 2.9633939e-04]
 [8.8214874e-06 2.5629997e-05 1.8000603e-05 3.6358833e-05 2.2232533e-05]
 [1.5932322e-04 3.5166740e-05 2.1210406e-05 9.1180205e-05 1.7516315e-04]
 [2.7435273e-04 5.5968761e-05 3.5285950e-05 8.4638596e-06 1.6790628e-04]]

Clearly, the decomposition is not properly implemented. I already checked the Hessenberg decomposition reconstruction error and, numerical precision aside, it seems to be working. Furthermore, the $H$ matrix generated by the Hessenberg QR iterations seems to be (at least) quasi-upper triangular. Thus, I have either a) incorrectly implemented the Hessenberg QR iterations, or b) failed to properly combine the Hessenberg and Schur factors.
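
For what it's worth, one way I can think of to separate hypotheses a) and b) (illustrative; it reuses `test_matrix` from above): any product of unitary factors, even a wrongly combined one, is still unitary, so if $Z$ passes the unitarity check and $T$ has the correct spectrum, the blame falls on how the factors are combined:

t0, z0 = schur_hessenberg(test_matrix, "real", 1000)

# Unitarity of the accumulated factor.
print(np.max(np.abs(np.asarray(conj_transpose(z0) @ z0) - np.eye(5))))

# If the QR iterations are correct, T is similar to A and the two
# spectra agree (up to ordering and floating-point noise).
print(np.sort(np.linalg.eigvals(np.asarray(t0))))
print(np.sort(np.linalg.eigvals(test_matrix)))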

I've already gone over the code and reference material several times, but I cannot spot where I'm screwing up, so I'd really appreciate any pointers.

Comments:

  • Why restrict yourself to only using a framework that apparently isn't well suited to the task? Why not just convert whatever input you have to numpy, let it do its work, and then convert back? – Oct 28 at 20:44
  • Because numpy doesn't run on CUDA. This is only a cog in a larger machine, and I need this decomposition to run once per batch over a set of matrices large enough that the cost of moving them between devices is non-negligible. Besides, the hard threshold set by the LAPACK developers suggests that there is an upper bound on the number of QR iterations required for convergence, regardless of the contents. Even setting all of the above aside, the error I'm getting is not due to the framework, but to an actual mistake I made when implementing the algorithm. The only thing is: I don't know exactly where. – Oct 28 at 21:43
  • FWIW, there's CuPy, which is basically numpy for GPUs, and CuPy can apparently handle tensors from JAX: docs.cupy.dev/en/stable/user_guide/… – Oct 29 at 6:48
  • @niemc, it also seems to accept PyTorch tensors, so it could have been what I was looking for. Unfortunately, it seems that both the Hessenberg reduction and the Schur decomposition have not been implemented in CuPy as of the latest stable version (13.6). Moreover, I would still like to understand what I'm doing wrong, even if I end up using an external implementation of the algorithm. – Oct 29 at 8:28

1 Answer

In the end, it was a variable shadowing bug: during the QR iterations, I was overwriting the q returned by hessenberg every time I called ops.qr.

Here is the fixed code:

def schur_hessenberg(a: Tensor, output: OutputType,
                     num_iters: int) -> tuple[Tensor, Tensor]:
    n = a.shape[0]

    # Before: h, q = hessenberg(a)
    h, q_hess = hessenberg(a)

    z = ops.eye(n)
    for m in range(n - 1, 0, -1):
        for _ in range(num_iters):
            sigma = h[m, m]
            q, r = ops.qr(h - sigma * ops.eye(n)) # Shadowing occurs here
            h = r @ q + sigma * ops.eye(n)
            z = z @ q

    if output == "complex" or str(a.dtype).startswith("complex"):
        return ops.triu(h), q_hess @ z
    else:
        return quasi_triu(h), q_hess @ z
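
To verify the fix, the reconstruction check from the question can simply be re-run:

t0, z0 = schur_hessenberg(test_matrix, "real", 100000)
print(np.abs(test_matrix - np.asarray(z0 @ t0 @ conj_transpose(z0))))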

I hope this implementation ends up being useful to someone else.
