Computes a low-rank approximation to the Cholesky decomposition.
tfp.substrates.jax.math.low_rank_cholesky(
matrix, max_rank, trace_atol=0, trace_rtol=0, name=None
)
This routine is similar to pivoted_cholesky, but works under JAX, at
the cost of being slightly less numerically stable.
Args:
  matrix: Floating point Tensor batch of symmetric, positive definite
    matrices, or a tf.linalg.LinearOperator.
  max_rank: Scalar int Tensor, the rank at which to truncate the
    approximation.
  trace_atol: Scalar floating point Tensor (same dtype as matrix). If
    trace_atol > 0 and trace(matrix - LR * LR^t) < trace_atol, the output
    LR matrix is allowed to be of rank less than max_rank.
  trace_rtol: Scalar floating point Tensor (same dtype as matrix). If
    trace_rtol > 0 and trace(matrix - LR * LR^t) < trace_rtol * trace(matrix),
    the output LR matrix is allowed to be of rank less than max_rank.
  name: Optional name for the op.
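How max_rank, trace_atol and trace_rtol interact can be illustrated with a plain NumPy sketch of a greedy pivoted partial Cholesky factorization. This is an assumption-laden illustration, not the TFP implementation: the function name low_rank_cholesky_sketch is hypothetical, it handles a single unbatched matrix, and the guard against a non-positive pivot is an addition made for this sketch.

```python
import numpy as np


def low_rank_cholesky_sketch(matrix, max_rank, trace_atol=0.0, trace_rtol=0.0):
    """Greedy pivoted partial Cholesky of one symmetric PSD matrix (illustrative)."""
    matrix = np.asarray(matrix, dtype=float)
    m = matrix.shape[-1]
    lr = np.zeros((m, max_rank))
    # Diagonal of the residual matrix - lr @ lr.T; its sum is the residual trace.
    residual_diag = np.diag(matrix).copy()
    orig_trace = residual_diag.sum()
    rank = 0
    while rank < max_rank:
        trace = residual_diag.sum()
        # Stop early once either enabled tolerance is met.
        if trace_atol > 0 and trace < trace_atol:
            break
        if trace_rtol > 0 and trace < trace_rtol * orig_trace:
            break
        pivot = int(np.argmax(residual_diag))
        if residual_diag[pivot] <= 0:
            break  # Nothing left to factor (guard added for this sketch).
        # New column: residual column at the pivot, scaled by sqrt of the pivot.
        column = matrix[:, pivot] - lr[:, :rank] @ lr[pivot, :rank]
        column = column / np.sqrt(residual_diag[pivot])
        lr[:, rank] = column
        # Clip at zero so rounding error cannot produce a negative diagonal.
        residual_diag = np.maximum(residual_diag - column**2, 0.0)
        rank += 1
    return lr[:, :rank], rank, residual_diag


# A 5 x 5 matrix of exact rank 2.
w = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0], [1.0, 2.0]])
a = w @ w.T

# With a relative tolerance, the factorization stops before reaching max_rank.
lr, r, residual_diag = low_rank_cholesky_sketch(a, max_rank=4, trace_rtol=1e-8)

# With both tolerances left at 0, the loop runs to max_rank.
lr1, r1, residual_diag1 = low_rank_cholesky_sketch(a, max_rank=1)
```

In the first call the residual trace drops to rounding-error level after two columns, so the relative tolerance ends the loop with r smaller than max_rank. In the second call neither tolerance is enabled, so the rank is determined by max_rank alone.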
Returns:
  A triplet (LR, r, residual_diag) of
  LR: A matrix such that LR * LR^t is approximately the input matrix.
    If matrix is of shape (b1, ..., bn, m, m), then LR will be of shape
    (b1, ..., bn, m, r) where r <= max_rank.
  r: The rank of LR. If r < max_rank, then
    trace(matrix - LR * LR^t) < trace_atol or
    trace(matrix - LR * LR^t) < trace_rtol * trace(matrix).
  residual_diag: The diagonal entries of matrix - LR * LR^t. This is
    returned because, together with LR, it is useful for preconditioning
    the input matrix.
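One way LR and residual_diag can be combined for preconditioning is to treat LR * LR^t + diag(residual_diag) as an approximation to the input matrix and apply its inverse with the Woodbury identity, which only requires solving an r x r system. The following NumPy sketch makes that concrete under stated assumptions: the helper name apply_preconditioner_inverse is hypothetical, the inputs are unbatched, residual_diag is assumed strictly positive, and the example values are made up for illustration rather than produced by low_rank_cholesky.

```python
import numpy as np


def apply_preconditioner_inverse(lr, residual_diag, rhs):
    """Solve (lr @ lr.T + diag(residual_diag)) x = rhs via the Woodbury identity."""
    d_inv = 1.0 / residual_diag  # Assumes strictly positive residual_diag.
    r = lr.shape[-1]
    # Small r x r "capacitance" matrix: I + lr^T D^{-1} lr.
    capacitance = np.eye(r) + lr.T @ (d_inv[:, None] * lr)
    d_inv_rhs = d_inv * rhs
    correction = lr @ np.linalg.solve(capacitance, lr.T @ d_inv_rhs)
    return d_inv_rhs - d_inv * correction


# Made-up factor and residual diagonal, standing in for the routine's outputs.
lr = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]])
residual_diag = np.array([0.5, 1.0, 2.0, 0.25])
rhs = np.array([1.0, 2.0, 3.0, 4.0])

# Dense form of the preconditioner, built here only to check the result.
preconditioner = lr @ lr.T + np.diag(residual_diag)
x = apply_preconditioner_inverse(lr, residual_diag, rhs)
```

The dense preconditioner is formed here only for checking; in practice the point of the low-rank-plus-diagonal structure is that the m x m matrix never needs to be built or factored.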