"""Mathematical operations and utilities."""
# ruff: noqa: FBT001, FBT002, RUF002, RUF003, RET503
from typing import Any, Final
import torch
from torch.autograd import Function
@torch.jit.script
def torch_divmod(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute division and modulo operations for tensors.
Args:
a: Dividend tensor
b: Divisor tensor
Returns:
tuple containing:
- Quotient tensor
- Remainder tensor
"""
d = torch.div(a, b, rounding_mode="floor")
m = a % b
return d, m
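# A quick sanity check (illustrative sketch): floor-mode divmod mirrors
# Python's divmod semantics elementwise, including negative dividends.
def _demo_torch_divmod() -> None:
    q, r = torch_divmod(torch.tensor([7, -7]), torch.tensor([3, 3]))
    assert torch.equal(q, torch.tensor([2, -3]))
    assert torch.equal(r, torch.tensor([1, 2]))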
"""Below code is taken from https://github.com/abhijeetgangan/torch_matfunc"""
def expm_frechet(  # noqa: C901
    A: torch.Tensor,
    E: torch.Tensor,
    method: str | None = None,
    compute_expm: bool = True,
    check_finite: bool = True,
) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
"""Frechet derivative of the matrix exponential of A in the direction E.
Args:
A: (N, N) array_like. Matrix of which to take the matrix exponential.
E: (N, N) array_like. Matrix direction in which to take the Frechet derivative.
method: str, optional. Choice of algorithm. Should be one of
- `SPS` (default)
- `blockEnlarge`
        compute_expm: bool, optional. Whether to also compute `expm_A` in addition
            to `expm_frechet_AE`. Default is True.
check_finite: bool, optional. Whether to check that the input matrix contains
only finite numbers. Disabling may give a performance gain, but may result
in problems (crashes, non-termination) if the inputs do contain
infinities or NaNs.
Returns:
If compute_expm is True:
expm_A: ndarray. Matrix exponential of A.
expm_frechet_AE: ndarray. Frechet derivative of the matrix exponential of A
in the direction E.
Otherwise:
expm_frechet_AE: ndarray. Frechet derivative of the matrix exponential of A
in the direction E.
"""
    # Convert inputs to torch tensors if they aren't already, so that the
    # finiteness check below operates on tensors
    if not isinstance(A, torch.Tensor):
        A = torch.tensor(A, dtype=torch.float64)
    if not isinstance(E, torch.Tensor):
        E = torch.tensor(E, dtype=torch.float64)
    if check_finite:
        if not torch.isfinite(A).all():
            raise ValueError("Matrix A contains non-finite values")
        if not torch.isfinite(E).all():
            raise ValueError("Matrix E contains non-finite values")
if A.dim() != 2 or A.shape[0] != A.shape[1]:
raise ValueError("expected A to be a square matrix")
if E.dim() != 2 or E.shape[0] != E.shape[1]:
raise ValueError("expected E to be a square matrix")
if A.shape != E.shape:
raise ValueError("expected A and E to be the same shape")
if method is None:
method = "SPS"
if method == "SPS":
expm_A, expm_frechet_AE = expm_frechet_algo_64(A, E)
elif method == "blockEnlarge":
expm_A, expm_frechet_AE = expm_frechet_block_enlarge(A, E)
else:
raise ValueError(f"Unknown implementation {method}")
if compute_expm:
return expm_A, expm_frechet_AE
return expm_frechet_AE
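# A finite-difference sanity check for `expm_frechet` (illustrative sketch;
# the matrix scale, step size, and tolerance are assumptions):
# expm(A + t*E) ~= expm(A) + t*L(A, E) for small t, so the one-sided
# difference quotient should approach the Frechet derivative.
def _demo_expm_frechet() -> None:
    torch.manual_seed(0)
    A = 0.1 * torch.randn(4, 4, dtype=torch.float64)
    E = torch.randn(4, 4, dtype=torch.float64)
    expm_A, expm_frechet_AE = expm_frechet(A, E)
    t = 1e-6
    finite_diff = (torch.matrix_exp(A + t * E) - expm_A) / t
    assert torch.allclose(finite_diff, expm_frechet_AE, atol=1e-4)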
def expm_frechet_block_enlarge(
A: torch.Tensor, E: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Helper function for testing and profiling.
Args:
A: Input matrix
E: Direction matrix
    Returns:
        expm_A: Matrix exponential of A
        expm_frechet_AE: Frechet derivative of the matrix exponential of A
            in the direction E
    """
n = A.shape[0]
# Create block matrix M = [[A, E], [0, A]]
M = torch.zeros((2 * n, 2 * n), dtype=A.dtype, device=A.device)
M[:n, :n] = A
M[:n, n:] = E
M[n:, n:] = A
# Use matrix exponential
expm_M = matrix_exp(M)
return expm_M[:n, :n], expm_M[:n, n:]
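# A cross-check of the two implementations (illustrative sketch; tolerances
# are assumptions): the block identity
# expm([[A, E], [0, A]]) = [[expm(A), L(A, E)], [0, expm(A)]]
# should agree with the scaling-and-squaring algorithm defined further below.
def _demo_expm_frechet_block_enlarge() -> None:
    torch.manual_seed(0)
    A = torch.randn(4, 4, dtype=torch.float64)
    E = torch.randn(4, 4, dtype=torch.float64)
    R_block, L_block = expm_frechet_block_enlarge(A, E)
    R_sps, L_sps = expm_frechet_algo_64(A, E)
    assert torch.allclose(R_block, R_sps, atol=1e-8)
    assert torch.allclose(L_block, L_sps, atol=1e-8)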
# Maximal values ell_m of ||2**-s A|| such that the backward error bound
# does not exceed 2**-53.
ell_table_61: Final = (
None,
# 1
2.11e-8,
3.56e-4,
1.08e-2,
6.49e-2,
2.00e-1,
4.37e-1,
7.83e-1,
1.23e0,
1.78e0,
2.42e0,
# 11
3.13e0,
3.90e0,
4.74e0,
5.63e0,
6.56e0,
7.52e0,
8.53e0,
9.56e0,
1.06e1,
1.17e1,
)
def _diff_pade3(
A: torch.Tensor, E: torch.Tensor, ident: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute Padé approximation of order 3 for matrix exponential and
its Frechet derivative.
Args:
A: Input matrix
E: Direction matrix
ident: Identity matrix of same shape as A
Returns:
U, V, Lu, Lv: Components needed for computing the matrix exponential and
its Frechet derivative
"""
b = (120.0, 60.0, 12.0, 1.0)
A2 = torch.matmul(A, A)
M2 = torch.matmul(A, E) + torch.matmul(E, A)
U = torch.matmul(A, b[3] * A2 + b[1] * ident)
V = b[2] * A2 + b[0] * ident
Lu = torch.matmul(A, b[3] * M2) + torch.matmul(E, b[3] * A2 + b[1] * ident)
Lv = b[2] * M2
return U, V, Lu, Lv
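# The U, V, Lu, Lv components from any of the `_diff_padeN` helpers combine
# into the Pade approximants via the solves used in `expm_frechet_algo_64`
# below. A minimal sketch for order 3 (valid only when ||A||_1 is within the
# ell_table_61 bound for m = 3; the norm scale and tolerance are assumptions):
def _demo_diff_pade3() -> None:
    torch.manual_seed(0)
    A = 1e-3 * torch.randn(3, 3, dtype=torch.float64)  # small norm: order-3 regime
    E = torch.randn(3, 3, dtype=torch.float64)
    ident = torch.eye(3, dtype=torch.float64)
    U, V, _Lu, _Lv = _diff_pade3(A, E, ident)
    R = torch.linalg.solve(V - U, U + V)  # r_3(A) ~= expm(A)
    assert torch.allclose(R, torch.matrix_exp(A), atol=1e-12)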
def _diff_pade5(
A: torch.Tensor, E: torch.Tensor, ident: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute Padé approximation of order 5 for matrix exponential and
its Frechet derivative.
Args:
A: Input matrix
E: Direction matrix
ident: Identity matrix of same shape as A
Returns:
U, V, Lu, Lv: Components needed for computing the matrix exponential and
its Frechet derivative
"""
b = (30240.0, 15120.0, 3360.0, 420.0, 30.0, 1.0)
A2 = torch.matmul(A, A)
M2 = torch.matmul(A, E) + torch.matmul(E, A)
A4 = torch.matmul(A2, A2)
M4 = torch.matmul(A2, M2) + torch.matmul(M2, A2)
U = torch.matmul(A, b[5] * A4 + b[3] * A2 + b[1] * ident)
V = b[4] * A4 + b[2] * A2 + b[0] * ident
Lu = torch.matmul(A, b[5] * M4 + b[3] * M2) + torch.matmul(
E, b[5] * A4 + b[3] * A2 + b[1] * ident
)
Lv = b[4] * M4 + b[2] * M2
return U, V, Lu, Lv
def _diff_pade7(
A: torch.Tensor, E: torch.Tensor, ident: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute Padé approximation of order 7 for matrix exponential and
its Frechet derivative.
Args:
A: Input matrix
E: Direction matrix
ident: Identity matrix of same shape as A
Returns:
U, V, Lu, Lv: Components needed for computing the matrix exponential and
its Frechet derivative
"""
b = (17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0, 1.0)
A2 = torch.matmul(A, A)
M2 = torch.matmul(A, E) + torch.matmul(E, A)
A4 = torch.matmul(A2, A2)
M4 = torch.matmul(A2, M2) + torch.matmul(M2, A2)
A6 = torch.matmul(A2, A4)
M6 = torch.matmul(A4, M2) + torch.matmul(M4, A2)
U = torch.matmul(A, b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident)
V = b[6] * A6 + b[4] * A4 + b[2] * A2 + b[0] * ident
Lu = torch.matmul(A, b[7] * M6 + b[5] * M4 + b[3] * M2) + torch.matmul(
E, b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident
)
Lv = b[6] * M6 + b[4] * M4 + b[2] * M2
return U, V, Lu, Lv
def _diff_pade9(
A: torch.Tensor, E: torch.Tensor, ident: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute Padé approximation of order 9 for matrix exponential and
its Frechet derivative.
Args:
A: Input matrix
E: Direction matrix
ident: Identity matrix of same shape as A
Returns:
U, V, Lu, Lv: Components needed for computing the matrix exponential and
its Frechet derivative
"""
b = (
17643225600.0,
8821612800.0,
2075673600.0,
302702400.0,
30270240.0,
2162160.0,
110880.0,
3960.0,
90.0,
1.0,
)
A2 = torch.matmul(A, A)
M2 = torch.matmul(A, E) + torch.matmul(E, A)
A4 = torch.matmul(A2, A2)
M4 = torch.matmul(A2, M2) + torch.matmul(M2, A2)
A6 = torch.matmul(A2, A4)
M6 = torch.matmul(A4, M2) + torch.matmul(M4, A2)
A8 = torch.matmul(A4, A4)
M8 = torch.matmul(A4, M4) + torch.matmul(M4, A4)
U = torch.matmul(A, b[9] * A8 + b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident)
V = b[8] * A8 + b[6] * A6 + b[4] * A4 + b[2] * A2 + b[0] * ident
Lu = torch.matmul(A, b[9] * M8 + b[7] * M6 + b[5] * M4 + b[3] * M2) + torch.matmul(
E, b[9] * A8 + b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident
)
Lv = b[8] * M8 + b[6] * M6 + b[4] * M4 + b[2] * M2
return U, V, Lu, Lv
def expm_frechet_algo_64(
A: torch.Tensor, E: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute matrix exponential and its Frechet derivative using Algorithm 6.4.
This implementation follows Al-Mohy and Higham's Algorithm 6.4 from
"Computing the Frechet Derivative of the Matrix Exponential, with an
application to Condition Number Estimation".
Args:
A: Input matrix
E: Direction matrix
Returns:
R: Matrix exponential of A
L: Frechet derivative of the matrix exponential in the direction E
"""
n = A.shape[0]
s = None
ident = torch.eye(n, dtype=A.dtype, device=A.device)
    # Induced 1-norm of A (max absolute column sum), as required by Algorithm 6.4
    A_norm_1 = torch.linalg.matrix_norm(A, ord=1)
m_pade_pairs = (
(3, _diff_pade3),
(5, _diff_pade5),
(7, _diff_pade7),
(9, _diff_pade9),
)
for m, pade in m_pade_pairs:
if A_norm_1 <= ell_table_61[m]:
U, V, Lu, Lv = pade(A, E, ident)
s = 0
break
if s is None:
# scaling
s = max(0, int(torch.ceil(torch.log2(A_norm_1 / ell_table_61[13]))))
A = A * 2.0**-s
E = E * 2.0**-s
# pade order 13
A2 = torch.matmul(A, A)
M2 = torch.matmul(A, E) + torch.matmul(E, A)
A4 = torch.matmul(A2, A2)
M4 = torch.matmul(A2, M2) + torch.matmul(M2, A2)
A6 = torch.matmul(A2, A4)
M6 = torch.matmul(A4, M2) + torch.matmul(M4, A2)
b = (
64764752532480000.0,
32382376266240000.0,
7771770303897600.0,
1187353796428800.0,
129060195264000.0,
10559470521600.0,
670442572800.0,
33522128640.0,
1323241920.0,
40840800.0,
960960.0,
16380.0,
182.0,
1.0,
)
W1 = b[13] * A6 + b[11] * A4 + b[9] * A2
W2 = b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident
Z1 = b[12] * A6 + b[10] * A4 + b[8] * A2
Z2 = b[6] * A6 + b[4] * A4 + b[2] * A2 + b[0] * ident
W = torch.matmul(A6, W1) + W2
U = torch.matmul(A, W)
V = torch.matmul(A6, Z1) + Z2
Lw1 = b[13] * M6 + b[11] * M4 + b[9] * M2
Lw2 = b[7] * M6 + b[5] * M4 + b[3] * M2
Lz1 = b[12] * M6 + b[10] * M4 + b[8] * M2
Lz2 = b[6] * M6 + b[4] * M4 + b[2] * M2
Lw = torch.matmul(A6, Lw1) + torch.matmul(M6, W1) + Lw2
Lu = torch.matmul(A, Lw) + torch.matmul(E, W)
Lv = torch.matmul(A6, Lz1) + torch.matmul(M6, Z1) + Lz2
# Solve the system (-U + V)X = (U + V) for R
R = torch.linalg.solve(-U + V, U + V)
# Solve the system (-U + V)X = (Lu + Lv + (Lu - Lv)R) for L
L = torch.linalg.solve(-U + V, Lu + Lv + torch.matmul(Lu - Lv, R))
# squaring
for _ in range(s):
L = torch.matmul(R, L) + torch.matmul(L, R)
R = torch.matmul(R, R)
return R, L
def matrix_exp(A: torch.Tensor) -> torch.Tensor:
"""Compute the matrix exponential of A using PyTorch's matrix_exp.
Args:
A: Input matrix
Returns:
Matrix exponential of A
"""
return torch.matrix_exp(A)
def vec(M: torch.Tensor) -> torch.Tensor:
"""Stack columns of M to construct a single vector.
This is somewhat standard notation in linear algebra.
Args:
M: Input matrix
Returns:
Output vector
"""
return M.t().reshape(-1)
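# `expm_cond` below calls `expm_frechet_kronform`, which is missing from this
# module. The following is a minimal sketch modeled on the scipy.linalg helper
# of the same name: it applies the Frechet derivative to each standard basis
# matrix E_ij = e_i e_j^T and stacks vec(L(A, E_ij)) as the columns of K.
# Note that `expm_cond` only uses the induced 2-norm of K, which is invariant
# to the column ordering.
def expm_frechet_kronform(
    A: torch.Tensor, method: str | None = None, check_finite: bool = True
) -> torch.Tensor:
    """Construct the Kronecker form of the Frechet derivative of expm.

    Args:
        A: Square (N, N) input matrix.
        method: Optional algorithm choice forwarded to `expm_frechet`.
        check_finite: Whether to check that A contains only finite numbers.

    Returns:
        K: (N*N, N*N) matrix such that vec(L(A, E)) = K @ vec(E).
    """
    if check_finite and not torch.isfinite(A).all():
        raise ValueError("Matrix A contains non-finite values")
    n = A.shape[0]
    ident = torch.eye(n, dtype=A.dtype, device=A.device)
    cols = []
    for j in range(n):
        for i in range(n):
            # vec(E_ij) is the (j*n + i)-th standard basis vector, matching
            # the column order built here
            E = torch.outer(ident[i], ident[j])
            F = expm_frechet(
                A, E, method=method, compute_expm=False, check_finite=False
            )
            cols.append(vec(F))
    return torch.stack(cols, dim=-1)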
def expm_cond(A: torch.Tensor, check_finite: bool = True) -> torch.Tensor:
"""Relative condition number of the matrix exponential in the Frobenius norm.
Args:
A: Square input matrix with shape (N, N)
check_finite: Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
Returns:
kappa: The relative condition number of the matrix exponential
in the Frobenius norm
"""
    # Convert input to torch tensor if it isn't already, so that the
    # finiteness check below operates on a tensor
    if not isinstance(A, torch.Tensor):
        A = torch.tensor(A, dtype=torch.float64)
    if check_finite and not torch.isfinite(A).all():
        raise ValueError("Matrix A contains non-finite values")
if A.dim() != 2 or A.shape[0] != A.shape[1]:
raise ValueError("expected a square matrix")
X = matrix_exp(A)
K = expm_frechet_kronform(A, check_finite=False)
# The following norm choices are deliberate.
# The norms of A and X are Frobenius norms,
# and the norm of K is the induced 2-norm.
norm_p = "fro"
A_norm = torch.norm(A, p=norm_p)
X_norm = torch.norm(X, p=norm_p)
K_norm = torch.linalg.matrix_norm(K, ord=2)
# kappa
return (K_norm * A_norm) / X_norm
class expm(Function): # noqa: N801
"""Compute the matrix exponential of a matrix or batch of matrices."""
@staticmethod
def forward(ctx: Any, A: torch.Tensor) -> torch.Tensor:
"""Compute the matrix exponential of A.
Args:
ctx: ctx
A: Input matrix or batch of matrices
Returns:
Matrix exponential of A
"""
# Save A for backward pass
ctx.save_for_backward(A)
# Use the matrix_exp function we already have
return matrix_exp(A)
@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
"""Compute the gradient of matrix exponential.
Args:
ctx: ctx
grad_output: Gradient with respect to the output
Returns:
Gradient with respect to the input
"""
        # Retrieve saved tensor
        (A,) = ctx.saved_tensors
        # The adjoint of the Frechet derivative L(A, .) is L(A^H, .), so the
        # vector-Jacobian product is the Frechet derivative evaluated at the
        # conjugate transpose of A in the direction grad_output
        return expm_frechet(
            A.mH, grad_output, method="SPS", compute_expm=False, check_finite=False
        )
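# A minimal usage sketch: `expm.apply` plugs into autograd, and
# torch.autograd.gradcheck can validate the backward pass numerically
# (a float64 input with requires_grad is assumed).
def _demo_expm_gradcheck() -> None:
    A = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
    assert torch.autograd.gradcheck(expm.apply, (A,))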
def _is_valid_matrix(T: torch.Tensor, n: int = 3) -> bool:
"""Check if T is a valid nxn matrix.
Args:
T: The matrix to check
n: The expected dimension of the matrix, default=3
Returns:
True if T is a valid nxn tensor, False otherwise
"""
return isinstance(T, torch.Tensor) and T.shape == (n, n)
def _determine_eigenvalue_case( # noqa: C901
T: torch.Tensor, eigenvalues: torch.Tensor, num_tol: float = 1e-16
) -> str:
"""Determine the eigenvalue structure case of matrix T.
Args:
T: The 3x3 matrix to analyze
eigenvalues: The eigenvalues of T
num_tol: Numerical tolerance for comparing eigenvalues, default=1e-16
Returns:
The case identifier ("case1a", "case1b", etc.)
Raises:
ValueError: If the eigenvalue structure cannot be determined
"""
# Get unique values and their counts directly with one call
unique_vals, counts = torch.unique(eigenvalues, return_counts=True)
    # Use torch.isclose to group eigenvalues that are numerically close:
    # create a mask for each unique value to see if other values are close to it
if len(unique_vals) > 1:
# Check if some "unique" values should actually be considered the same
i = 0
while i < len(unique_vals):
# Find all values close to the current one
close_mask = torch.isclose(unique_vals, unique_vals[i], rtol=0, atol=num_tol)
close_count = torch.sum(close_mask)
if close_count > 1: # If there are other close values
# Merge them (keep the first one, remove the others)
counts[i] = torch.sum(counts[close_mask])
                # Keep entry i itself and drop the other entries close to it;
                # the comparison must be parenthesized, since `&` binds tighter
                # than `!=`
                idx = torch.arange(len(close_mask), device=close_mask.device)
                keep = ~(close_mask & (idx != i))
                unique_vals = unique_vals[keep]
                counts = counts[keep]
else:
i += 1
# Now determine the case based on the number of unique eigenvalues
if len(unique_vals) == 1:
# Case 1: All eigenvalues are equal (λ, λ, λ)
lambda_val = unique_vals[0]
Identity = torch.eye(3, dtype=lambda_val.dtype, device=lambda_val.device)
T_minus_lambdaI = T - lambda_val * Identity
rank1 = torch.linalg.matrix_rank(T_minus_lambdaI)
if rank1 == 0:
return "case1a" # q(T) = (T - λI)
rank2 = torch.linalg.matrix_rank(T_minus_lambdaI @ T_minus_lambdaI)
if rank2 == 0:
return "case1b" # q(T) = (T - λI)²
return "case1c" # q(T) = (T - λI)³
if len(unique_vals) == 2:
# Case 2: Two distinct eigenvalues
# The counts array already tells us which eigenvalue is repeated
if counts.max() != 2 or counts.min() != 1:
raise ValueError("Unexpected eigenvalue pattern for Case 2")
mu = unique_vals[torch.argmin(counts)] # The non-repeated eigenvalue
lambda_val = unique_vals[torch.argmax(counts)] # The repeated eigenvalue
Identity = torch.eye(3, dtype=lambda_val.dtype, device=lambda_val.device)
T_minus_muI = T - mu * Identity
T_minus_lambdaI = T - lambda_val * Identity
        # Check whether q(T) = (T - μI)(T - λI) annihilates T, i.e. q(T) = 0
        if torch.allclose(
            T_minus_muI @ T_minus_lambdaI,
            torch.zeros((3, 3), dtype=lambda_val.dtype, device=lambda_val.device),
        ):
return "case2a" # q(T) = (T - λI)(T - μI)
return "case2b" # q(T) = (T - μI)(T - λI)²
if len(unique_vals) == 3:
# Case 3: Three distinct eigenvalues (λ, μ, ν)
return "case3" # q(T) = (T - λI)(T - μI)(T - νI)
raise ValueError("Could not determine eigenvalue structure")
def _matrix_log_case1a(T: torch.Tensor, lambda_val: torch.Tensor) -> torch.Tensor:
"""Compute log(T) when q(T) = (T - λI).
This is the case where T is a scalar multiple of the identity matrix.
Args:
T: The matrix whose logarithm is to be computed
lambda_val: The eigenvalue of T
Returns:
The logarithm of T, which is log(λ)·I
"""
n = T.shape[0]
Identity = torch.eye(n, dtype=lambda_val.dtype, device=lambda_val.device)
return torch.log(lambda_val) * Identity
def _matrix_log_case1b(
    T: torch.Tensor, lambda_val: torch.Tensor, num_tol: float = 1e-16
) -> torch.Tensor:
"""Compute log(T) when q(T) = (T - λI)².
This is the case where T has a Jordan block of size 2.
Args:
T: The matrix whose logarithm is to be computed
lambda_val: The eigenvalue of T
num_tol: Numerical tolerance for stability checks, default=1e-16
Returns:
The logarithm of T
"""
n = T.shape[0]
Identity = torch.eye(n, dtype=lambda_val.dtype, device=lambda_val.device)
T_minus_lambdaI = T - lambda_val * Identity
    # log T = log(λ)I + (T - λI)/λ; the caller guarantees |λ| >= num_tol, and
    # the clamp below is a last-resort guard on the magnitude of the divisor
    denom = lambda_val if abs(lambda_val) > num_tol else num_tol
    return torch.log(lambda_val) * Identity + T_minus_lambdaI / denom
def _matrix_log_case1c(
    T: torch.Tensor, lambda_val: torch.Tensor, num_tol: float = 1e-16
) -> torch.Tensor:
"""Compute log(T) when q(T) = (T - λI)³.
This is the case where T has a Jordan block of size 3.
Args:
T: The matrix whose logarithm is to be computed
lambda_val: The eigenvalue of T
num_tol: Numerical tolerance for stability checks, default=1e-16
Returns:
The logarithm of T
"""
n = T.shape[0]
Identity = torch.eye(n, dtype=lambda_val.dtype, device=lambda_val.device)
T_minus_lambdaI = T - lambda_val * Identity
# Compute (T - λI)² with better numerical stability
T_minus_lambdaI_squared = T_minus_lambdaI @ T_minus_lambdaI
    # log T = log(λ)I + (T - λI)/λ - (T - λI)²/(2λ²); the caller guarantees
    # |λ| >= num_tol, and the clamps below are last-resort magnitude guards
    lambda_squared = lambda_val * lambda_val
    denom1 = lambda_val if abs(lambda_val) > num_tol else num_tol
    denom2 = 2 * lambda_squared if abs(2 * lambda_squared) > num_tol else num_tol
    term1 = torch.log(lambda_val) * Identity
    term2 = T_minus_lambdaI / denom1
    term3 = T_minus_lambdaI_squared / denom2
return term1 + term2 - term3
def _matrix_log_case2a(
    T: torch.Tensor, lambda_val: torch.Tensor, mu: torch.Tensor, num_tol: float = 1e-16
) -> torch.Tensor:
"""Compute log(T) when q(T) = (T - λI)(T - μI) with λ≠μ.
This is the case with two distinct eigenvalues.
Formula: log T = log μ((T - λI)/(μ - λ)) + log λ((T - μI)/(λ - μ))
Args:
T: The matrix whose logarithm is to be computed
lambda_val: The repeated eigenvalue of T
mu: The non-repeated eigenvalue of T
num_tol: Numerical tolerance for stability checks, default=1e-16
Returns:
The logarithm of T
Raises:
ValueError: If λ and μ are too close for numerical stability
"""
n = T.shape[0]
Identity = torch.eye(n, dtype=lambda_val.dtype, device=lambda_val.device)
lambda_minus_mu = lambda_val - mu
# Check for numerical stability
if torch.abs(lambda_minus_mu) < num_tol:
raise ValueError("λ and μ are too close, computation may be unstable")
T_minus_lambdaI = T - lambda_val * Identity
T_minus_muI = T - mu * Identity
# Compute each term separately for better numerical stability
term1 = torch.log(mu) * (T_minus_lambdaI / (mu - lambda_val))
term2 = torch.log(lambda_val) * (T_minus_muI / (lambda_val - mu))
return term1 + term2
def _matrix_log_case2b(
    T: torch.Tensor, lambda_val: torch.Tensor, mu: torch.Tensor, num_tol: float = 1e-16
) -> torch.Tensor:
"""Compute log(T) when q(T) = (T - μI)(T - λI)² with λ≠μ.
This is the case with one eigenvalue of multiplicity 2 and one distinct eigenvalue.
Formula: log T = log μ((T - λI)²/(λ - μ)²) -
log λ((T - μI)(T - (2λ - μ)I)/(λ - μ)²) +
((T - λI)(T - μI)/(λ(λ - μ)))
Args:
T: The matrix whose logarithm is to be computed
lambda_val: The repeated eigenvalue of T
mu: The non-repeated eigenvalue of T
num_tol: Numerical tolerance for stability checks, default=1e-16
Returns:
The logarithm of T
Raises:
ValueError: If λ and μ are too close for numerical stability or
if λ is too close to zero
"""
n = T.shape[0]
Identity = torch.eye(n, dtype=lambda_val.dtype, device=lambda_val.device)
lambda_minus_mu = lambda_val - mu
lambda_minus_mu_squared = lambda_minus_mu * lambda_minus_mu
# Check for numerical stability
if torch.abs(lambda_minus_mu) < num_tol:
raise ValueError("λ and μ are too close, computation may be unstable")
if torch.abs(lambda_val) < num_tol:
raise ValueError("λ is too close to zero, computation may be unstable")
T_minus_lambdaI = T - lambda_val * Identity
T_minus_muI = T - mu * Identity
T_minus_lambdaI_squared = T_minus_lambdaI @ T_minus_lambdaI
# The term (T - (2λ - μ)I)
T_minus_2lambda_plus_muI = T - (2 * lambda_val - mu) * Identity
# Compute each term separately for better numerical stability
term1 = torch.log(mu) * (T_minus_lambdaI_squared / lambda_minus_mu_squared)
term2 = -torch.log(lambda_val) * (
(T_minus_muI @ T_minus_2lambda_plus_muI) / lambda_minus_mu_squared
)
term3 = (T_minus_lambdaI @ T_minus_muI) / (lambda_val * lambda_minus_mu)
return term1 + term2 + term3
def _matrix_log_case3(
    T: torch.Tensor,
    lambda_val: torch.Tensor,
    mu: torch.Tensor,
    nu: torch.Tensor,
    num_tol: float = 1e-16,
) -> torch.Tensor:
"""Compute log(T) when q(T) = (T - λI)(T - μI)(T - νI) with λ≠μ≠ν≠λ.
This is the case with three distinct eigenvalues.
Formula: log T = log λ((T - μI)(T - νI)/((λ - μ)(λ - ν)))
+ log μ((T - λI)(T - νI)/((μ - λ)(μ - ν)))
+ log ν((T - λI)(T - μI)/((ν - λ)(ν - μ)))
Args:
T: The matrix whose logarithm is to be computed
lambda_val: First eigenvalue of T
mu: Second eigenvalue of T
nu: Third eigenvalue of T
        num_tol: Numerical tolerance for stability checks, default=1e-16
Returns:
The logarithm of T
Raises:
ValueError: If any pair of eigenvalues are too close for numerical stability
"""
n = T.shape[0]
Identity = torch.eye(n, dtype=lambda_val.dtype, device=lambda_val.device)
# Check if eigenvalues are distinct enough for numerical stability
if (
min(torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu))
< num_tol
):
raise ValueError("Eigenvalues are too close, computation may be unstable")
T_minus_lambdaI = T - lambda_val * Identity
T_minus_muI = T - mu * Identity
T_minus_nuI = T - nu * Identity
# Compute the terms for λ
lambda_term_numerator = T_minus_muI @ T_minus_nuI
lambda_term_denominator = (lambda_val - mu) * (lambda_val - nu)
lambda_term = torch.log(lambda_val) * (
lambda_term_numerator / lambda_term_denominator
)
# Compute the terms for μ
mu_term_numerator = T_minus_lambdaI @ T_minus_nuI
mu_term_denominator = (mu - lambda_val) * (mu - nu)
mu_term = torch.log(mu) * (mu_term_numerator / mu_term_denominator)
# Compute the terms for ν
nu_term_numerator = T_minus_lambdaI @ T_minus_muI
nu_term_denominator = (nu - lambda_val) * (nu - mu)
nu_term = torch.log(nu) * (nu_term_numerator / nu_term_denominator)
return lambda_term + mu_term + nu_term
def _matrix_log_33( # noqa: C901
T: torch.Tensor, case: str = "auto", dtype: torch.dtype = torch.float64
) -> torch.Tensor:
"""Compute the logarithm of 3x3 matrix T based on its eigenvalue structure.
The logarithm of this matrix is known exactly as given the in the references.
Args:
T: The matrix whose logarithm is to be computed
case: One of "auto", "case1a", "case1b", "case1c", "case2a", "case2b", "case3"
- "auto": Automatically determine the structure
- "case1a": All eigenvalues are equal, q(T) = (T - λI)
- "case1b": All eigenvalues are equal, q(T) = (T - λI)²
- "case1c": All eigenvalues are equal, q(T) = (T - λI)³
- "case2a": Two distinct eigenvalues, q(T) = (T - λI)(T - μI)
- "case2b": Two distinct eigenvalues, q(T) = (T - μI)(T - λI)²
- "case3": Three distinct eigenvalues, q(T) = (T - λI)(T - μI)(T - νI)
dtype: The data type to use for numerical tolerance, default=torch.float64
Returns:
The logarithm of T
References:
- https://link.springer.com/article/10.1007/s10659-008-9169-x
"""
num_tol = 1e-16 if dtype == torch.float64 else 1e-8
if not _is_valid_matrix(T):
raise ValueError("Input must be a 3x3 matrix")
# Compute eigenvalues
eigenvalues = torch.linalg.eigvals(T)
# Convert eigenvalues to real if they're complex but with tiny imaginary parts
eigenvalues = (
torch.real(eigenvalues)
if torch.allclose(
torch.imag(eigenvalues),
torch.zeros_like(torch.imag(eigenvalues)),
atol=num_tol,
)
else eigenvalues
)
# If automatic detection, determine the structure
if case == "auto":
case = _determine_eigenvalue_case(T, eigenvalues, num_tol)
# Case 1: All eigenvalues are equal (λ, λ, λ)
if case in ["case1a", "case1b", "case1c"]:
lambda_val = eigenvalues[0]
# Check for numerical stability
if torch.abs(lambda_val) < num_tol:
raise ValueError("Eigenvalue too close to zero, computation may be unstable")
if case == "case1a":
return _matrix_log_case1a(T, lambda_val)
if case == "case1b":
return _matrix_log_case1b(T, lambda_val, num_tol)
if case == "case1c":
return _matrix_log_case1c(T, lambda_val, num_tol)
# Case 2: Two distinct eigenvalues (μ, λ, λ)
elif case in ["case2a", "case2b"]:
# Find the unique eigenvalue (μ) and the repeated eigenvalue (λ)
unique_vals, counts = torch.unique(
torch.round(eigenvalues, decimals=10), return_counts=True
)
if len(unique_vals) != 2 or counts.max() != 2:
raise ValueError(
"Case 2 requires exactly two distinct eigenvalues with one repeated"
)
mu = unique_vals[torch.argmin(counts)] # The non-repeated eigenvalue
lambda_val = unique_vals[torch.argmax(counts)] # The repeated eigenvalue
if case == "case2a":
return _matrix_log_case2a(T, lambda_val, mu, num_tol)
if case == "case2b":
return _matrix_log_case2b(T, lambda_val, mu, num_tol)
# Case 3: Three distinct eigenvalues (λ, μ, ν)
elif case == "case3":
if len(torch.unique(torch.round(eigenvalues, decimals=10))) != 3:
raise ValueError("Case 3 requires three distinct eigenvalues")
        # Sort for consistency; sort by real part, since complex tensors do
        # not support torch.sort directly
        lambda_val, mu, nu = eigenvalues[torch.argsort(eigenvalues.real)]
return _matrix_log_case3(T, lambda_val, mu, nu, num_tol)
else:
raise ValueError(f"Unknown eigenvalue {case=}")
def matrix_log_scipy(matrix: torch.Tensor) -> torch.Tensor:
"""Compute the matrix logarithm of a square matrix using scipy.linalg.logm.
    This function handles tensors on CPU or GPU. The computation itself is
    detached from the autograd graph, so gradients do not flow through the
    scipy call; the output is only marked as requiring grad if the input did.
Args:
matrix: A square matrix tensor
Returns:
torch.Tensor: The matrix logarithm of the input matrix
"""
import scipy.linalg
# Save original device and dtype
device = matrix.device
dtype = matrix.dtype
requires_grad = matrix.requires_grad
# Detach and move to CPU for scipy
matrix_cpu = matrix.detach().cpu().numpy()
# Compute the logarithm using scipy
result_np = scipy.linalg.logm(matrix_cpu)
# Convert back to tensor and move to original device
result = torch.tensor(result_np, dtype=dtype, device=device)
    # If the input required grad, mark the output as requiring grad too (this
    # does not connect the output to the input's autograd graph)
if requires_grad:
result = result.requires_grad_()
return result
def matrix_log_33(
matrix: torch.Tensor,
sim_dtype: torch.dtype = torch.float64,
fallback_warning: bool = False,
) -> torch.Tensor:
"""Compute the matrix logarithm of a square 3x3 matrix.
Args:
matrix: A square 3x3 matrix tensor
sim_dtype: Simulation dtype, default=torch.float64
fallback_warning: Whether to print a warning when falling back to scipy,
default=False
Returns:
The matrix logarithm of the input matrix
This function attempts to use the exact formula for 3x3 matrices first,
and falls back to scipy implementation if that fails.
"""
# Convert to double precision for stability
matrix = matrix.to(torch.float64)
try:
return _matrix_log_33(matrix).to(sim_dtype)
except (ValueError, RuntimeError) as exc:
msg = (
f"Error computing matrix logarithm with _matrix_log_33 {exc} \n"
"Falling back to scipy"
)
if fallback_warning:
print(msg)
# Fall back to scipy implementation
return matrix_log_scipy(matrix).to(sim_dtype)
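# A round-trip sketch for the public wrapper: for a diagonalizable matrix with
# distinct positive eigenvalues (case3), expm(log T) should recover T. The
# test matrix and tolerance here are assumptions chosen to keep the spectrum
# real and well separated.
def _demo_matrix_log_33() -> None:
    torch.manual_seed(0)
    Q = torch.randn(3, 3, dtype=torch.float64)
    D = torch.diag(torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64))
    T = Q @ D @ torch.linalg.inv(Q)
    log_T = matrix_log_33(T)
    assert torch.allclose(torch.matrix_exp(log_T), T, atol=1e-8)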