# ruff: noqa: RUF002, RUF003, PLC2401
"""Calculation of elastic properties of crystals.
Primary Sources and References for Crystal Elasticity.
- Landau, L.D. & Lifshitz, E.M. "Theory of Elasticity" (Volume 7 of Course of
Theoretical Physics)
- Teodosiu, C. (1982) "Elastic Models of Crystal Defects"
Review Articles:
- Mouhat, F., & Coudert, F. X. (2014).
"Necessary and sufficient elastic stability conditions in various crystal systems"
Physical Review B, 90(22), 224104
Online Resources:
- Materials Project Documentation
https://docs.materialsproject.org/methodology/elasticity/
"""
from collections.abc import Callable
from dataclasses import dataclass
import torch
from torch_sim.state import SimState
from torch_sim.typing import BravaisType
[docs]
def get_bravais_type(  # noqa: PLR0911
    state: SimState, length_tol: float = 1e-3, angle_tol: float = 0.1
) -> BravaisType:
    """Classify the lattice of ``state`` into a crystal system.

    Only the lattice lengths (a, b, c) and inter-axial angles (α, β, γ) of the
    cell are used; no space-group analysis (e.g. spglib) is performed.
    Candidate systems are tested in a fixed order and the first match wins:
    cubic, hexagonal, tetragonal, orthorhombic, monoclinic, trigonal
    (rhombohedral), with triclinic as the fallback.

    Args:
        state: SimState whose ``row_vector_cell`` holds the lattice vectors as
            rows. The cell is ``squeeze()``-d, so a single system is assumed.
        length_tol: Absolute tolerance for comparing lattice lengths.
        angle_tol: Absolute tolerance, in degrees, for comparing angles.

    Returns:
        BravaisType: The first crystal system whose metric conditions hold.

    Notes:
        "Equal" means ``|x - y| < tol`` and "different" means ``|x - y| > tol``;
        a difference of exactly ``tol`` satisfies neither. The hexagonal test
        does not require ``a != c``, and a trigonal cell supplied in its
        hexagonal setting (a = b, α = β = 90°, γ = 120°) is reported as
        hexagonal.
    """
    cell = state.row_vector_cell.squeeze()
    lengths = torch.linalg.norm(cell, axis=1)
    a, b, c = lengths

    def _angle_deg(i: int, j: int) -> torch.Tensor:
        # Angle between lattice vectors i and j, in degrees.
        cosine = torch.dot(cell[i], cell[j]) / (lengths[i] * lengths[j])
        return torch.rad2deg(torch.arccos(cosine))

    alpha = _angle_deg(1, 2)  # angle between b and c
    beta = _angle_deg(0, 2)  # angle between a and c
    gamma = _angle_deg(0, 1)  # angle between a and b

    def _same(x, y, tol):  # noqa: ANN001, ANN202
        return abs(x - y) < tol

    def _differ(x, y, tol):  # noqa: ANN001, ANN202
        return abs(x - y) > tol

    a_eq_b = _same(a, b, length_tol)
    b_eq_c = _same(b, c, length_tol)
    alpha_right = _same(alpha, 90, angle_tol)
    beta_right = _same(beta, 90, angle_tol)
    gamma_right = _same(gamma, 90, angle_tol)
    all_right = alpha_right and beta_right and gamma_right

    if a_eq_b and b_eq_c and all_right:
        return BravaisType.CUBIC
    if a_eq_b and alpha_right and beta_right and _same(gamma, 120, angle_tol):
        return BravaisType.HEXAGONAL
    if a_eq_b and _differ(a, c, length_tol) and all_right:
        return BravaisType.TETRAGONAL
    if (
        all_right
        and _differ(a, b, length_tol)
        and (_differ(b, c, length_tol) or _differ(a, c, length_tol))
    ):
        return BravaisType.ORTHORHOMBIC
    if alpha_right and gamma_right and _differ(beta, 90, angle_tol):
        return BravaisType.MONOCLINIC
    if (
        a_eq_b
        and b_eq_c
        and _same(alpha, beta, angle_tol)
        and _same(beta, gamma, angle_tol)
        and _differ(alpha, 90, angle_tol)
    ):
        return BravaisType.TRIGONAL
    # Nothing more specific matched.
    return BravaisType.TRICLINIC
[docs]
def regular_symmetry(strains: torch.Tensor) -> torch.Tensor:
    """Build the least-squares equation matrix for cubic (regular) symmetry.

    Cubic crystals have three independent elastic constants (C11, C12, C44).
    Row ``i`` of the result holds the coefficients multiplying those constants
    in the Voigt stress component ``σᵢ = Σⱼ Cᵢⱼ εⱼ``.

    Args:
        strains: Strain vector of shape (6,) ordered
            [εxx, εyy, εzz, εyz, εxz, εxy]. Shear entries are multiplied by 2
            below, i.e. tensor (not engineering) shear strains are expected.
            Non-tensor input is converted with ``torch.tensor``.

    Returns:
        torch.Tensor: Matrix of shape (6, 3) on the device/dtype of
        ``strains``; columns are the coefficients of C11, C12, C44.

    Raises:
        ValueError: If ``strains`` does not have shape (6,).

    Notes:
        Encoded relations:
            σxx = C11 εxx + C12 (εyy + εzz)
            σyy = C11 εyy + C12 (εxx + εzz)
            σzz = C11 εzz + C12 (εxx + εyy)
            σyz = 2 C44 εyz
            σxz = 2 C44 εxz
            σxy = 2 C44 εxy
    """
    strains = strains if isinstance(strains, torch.Tensor) else torch.tensor(strains)
    if strains.shape != (6,):
        raise ValueError("Strains tensor must have shape (6,)")

    e_xx, e_yy, e_zz, e_yz, e_xz, e_xy = strains.unbind()

    # (row, column, value) for every non-zero entry; everything else is zero.
    entries = (
        (0, 0, e_xx),
        (1, 0, e_yy),
        (2, 0, e_zz),
        (0, 1, e_yy + e_zz),
        (1, 1, e_xx + e_zz),
        (2, 1, e_xx + e_yy),
        (3, 2, 2 * e_yz),
        (4, 2, 2 * e_xz),
        (5, 2, 2 * e_xy),
    )
    matrix = torch.zeros((6, 3), dtype=strains.dtype, device=strains.device)
    for row, col, value in entries:
        matrix[row, col] = value
    return matrix
[docs]
def tetragonal_symmetry(strains: torch.Tensor) -> torch.Tensor:
    """Build the least-squares equation matrix for tetragonal symmetry.

    The 7-constant tetragonal form (C11, C12, C13, C16, C33, C44, C66, with
    C26 = -C16) is used.

    Args:
        strains: Strain vector of shape (6,) ordered
            [εxx, εyy, εzz, εyz, εxz, εxy]; non-tensor input is converted
            with ``torch.tensor``.

    Returns:
        torch.Tensor: Matrix of shape (6, 7) on the device/dtype of
        ``strains``; columns are the coefficients of C11, C12, C13, C16, C33,
        C44, C66, in that order.

    Raises:
        ValueError: If ``strains`` does not have shape (6,).

    Notes:
        Encoded relations:
            σxx = C11 εxx + C12 εyy + C13 εzz + 2 C16 εxy
            σyy = C12 εxx + C11 εyy + C13 εzz - 2 C16 εxy
            σzz = C13 (εxx + εyy) + C33 εzz
            σyz = 2 C44 εyz
            σxz = 2 C44 εxz
            σxy = C16 (εxx - εyy) + 2 C66 εxy
    """
    strains = strains if isinstance(strains, torch.Tensor) else torch.tensor(strains)
    if strains.shape != (6,):
        raise ValueError("Strains tensor must have shape (6,)")

    e_xx, e_yy, e_zz, e_yz, e_xz, e_xy = strains.unbind()

    # (row, column, value) for every non-zero entry.
    entries = (
        (0, 0, e_xx),
        (0, 1, e_yy),
        (0, 2, e_zz),
        (0, 3, 2 * e_xy),
        (1, 0, e_yy),
        (1, 1, e_xx),
        (1, 2, e_zz),
        (1, 3, -2 * e_xy),
        (2, 2, e_xx + e_yy),
        (2, 4, e_zz),
        (3, 5, 2 * e_yz),
        (4, 5, 2 * e_xz),
        (5, 3, e_xx - e_yy),
        (5, 6, 2 * e_xy),
    )
    matrix = torch.zeros((6, 7), dtype=strains.dtype, device=strains.device)
    for row, col, value in entries:
        matrix[row, col] = value
    return matrix
[docs]
def orthorhombic_symmetry(strains: torch.Tensor) -> torch.Tensor:
    """Build the least-squares equation matrix for orthorhombic symmetry.

    Orthorhombic crystals have nine independent elastic constants.

    Args:
        strains: Strain vector of shape (6,) ordered
            [εxx, εyy, εzz, εyz, εxz, εxy]; non-tensor input is converted
            with ``torch.tensor``.

    Returns:
        torch.Tensor: Matrix of shape (6, 9) on the device/dtype of
        ``strains``; columns are the coefficients of C11, C12, C13, C22, C23,
        C33, C44, C55, C66, in that order.

    Raises:
        ValueError: If ``strains`` does not have shape (6,).

    Notes:
        Encoded relations:
            σxx = C11 εxx + C12 εyy + C13 εzz
            σyy = C12 εxx + C22 εyy + C23 εzz
            σzz = C13 εxx + C23 εyy + C33 εzz
            σyz = 2 C44 εyz
            σxz = 2 C55 εxz
            σxy = 2 C66 εxy
    """
    strains = strains if isinstance(strains, torch.Tensor) else torch.tensor(strains)
    if strains.shape != (6,):
        raise ValueError("Strains tensor must have shape (6,)")

    e_xx, e_yy, e_zz, e_yz, e_xz, e_xy = strains.unbind()

    # Column indices: C11=0, C12=1, C13=2, C22=3, C23=4, C33=5, C44=6, C55=7, C66=8
    entries = (
        (0, 0, e_xx),
        (0, 1, e_yy),
        (0, 2, e_zz),
        (1, 1, e_xx),
        (1, 3, e_yy),
        (1, 4, e_zz),
        (2, 2, e_xx),
        (2, 4, e_yy),
        (2, 5, e_zz),
        (3, 6, 2 * e_yz),
        (4, 7, 2 * e_xz),
        (5, 8, 2 * e_xy),
    )
    matrix = torch.zeros((6, 9), dtype=strains.dtype, device=strains.device)
    for row, col, value in entries:
        matrix[row, col] = value
    return matrix
[docs]
def trigonal_symmetry(strains: torch.Tensor) -> torch.Tensor:
    """Build the least-squares equation matrix for trigonal symmetry.

    The 7-constant trigonal form (C11, C12, C13, C14, C15, C33, C44) is used,
    with C66 = (C11 - C12) / 2 folded into the last row.

    Args:
        strains: Strain vector of shape (6,) ordered
            [εxx, εyy, εzz, εyz, εxz, εxy]; non-tensor input is converted
            with ``torch.tensor``.

    Returns:
        torch.Tensor: Matrix of shape (6, 7) on the device/dtype of
        ``strains``; columns are the coefficients of C11, C12, C13, C14, C15,
        C33, C44, in that order.

    Raises:
        ValueError: If ``strains`` does not have shape (6,).

    Notes:
        Encoded relations:
            σxx = C11 εxx + C12 εyy + C13 εzz + 2 C14 εyz + 2 C15 εxz
            σyy = C12 εxx + C11 εyy + C13 εzz - 2 C14 εyz - 2 C15 εxz
            σzz = C13 (εxx + εyy) + C33 εzz
            σyz = C14 (εxx - εyy) - 2 C15 εxy + 2 C44 εyz
            σxz = 2 C14 εxy + C15 (εxx - εyy) + 2 C44 εxz
            σxy = (C11 - C12) εxy + 2 C14 εxz - 2 C15 εyz
    """
    strains = strains if isinstance(strains, torch.Tensor) else torch.tensor(strains)
    if strains.shape != (6,):
        raise ValueError("Strains tensor must have shape (6,)")

    e_xx, e_yy, e_zz, e_yz, e_xz, e_xy = strains.unbind()

    # Column indices: C11=0, C12=1, C13=2, C14=3, C15=4, C33=5, C44=6
    entries = (
        (0, 0, e_xx),
        (0, 1, e_yy),
        (0, 2, e_zz),
        (0, 3, 2 * e_yz),
        (0, 4, 2 * e_xz),
        (1, 0, e_yy),
        (1, 1, e_xx),
        (1, 2, e_zz),
        (1, 3, -2 * e_yz),
        (1, 4, -2 * e_xz),
        (2, 2, e_xx + e_yy),
        (2, 5, e_zz),
        (3, 3, e_xx - e_yy),
        (3, 4, -2 * e_xy),
        (3, 6, 2 * e_yz),
        (4, 3, 2 * e_xy),
        (4, 4, e_xx - e_yy),
        (4, 6, 2 * e_xz),
        (5, 0, e_xy),
        (5, 1, -e_xy),
        (5, 3, 2 * e_xz),
        (5, 4, -2 * e_yz),
    )
    matrix = torch.zeros((6, 7), dtype=strains.dtype, device=strains.device)
    for row, col, value in entries:
        matrix[row, col] = value
    return matrix
[docs]
def hexagonal_symmetry(strains: torch.Tensor) -> torch.Tensor:
    """Generate equation matrix for hexagonal crystal symmetry.

    Hexagonal crystals have five independent elastic constants: C11, C12,
    C13, C33 and C44. C66 = (C11 - C12) / 2 is not independent; it enters
    through the last row of the matrix.

    Args:
        strains: Tensor of shape (6,) containing strain components
            [εxx, εyy, εzz, εyz, εxz, εxy]; non-tensor input is converted
            with ``torch.tensor``.

    Returns:
        torch.Tensor: Matrix of shape (6, 5) on the device/dtype of
        ``strains``. The columns are the coefficients of C11, C12, C13, C33
        and C44, in that order. This is the same order that
        ``get_elastic_tensor_from_coeffs`` unpacks for hexagonal symmetry.

    Raises:
        ValueError: If ``strains`` does not have shape (6,).

    Notes:
        The resulting matrix M (columns C11, C12, C13, C33, C44) has the form:
            ⎡ εxx   εyy   εzz       0    0   ⎤
            ⎢ εyy   εxx   εzz       0    0   ⎥
            ⎢ 0     0     εxx+εyy   εzz  0   ⎥
            ⎢ 0     0     0         0    2εyz⎥
            ⎢ 0     0     0         0    2εxz⎥
            ⎣ εxy  -εxy   0         0    0   ⎦
        i.e.
            σxx = C11 εxx + C12 εyy + C13 εzz
            σyy = C12 εxx + C11 εyy + C13 εzz
            σzz = C13 (εxx + εyy) + C33 εzz
            σyz = 2 C44 εyz
            σxz = 2 C44 εxz
            σxy = 2 C66 εxy = (C11 - C12) εxy
    """
    if not isinstance(strains, torch.Tensor):
        strains = torch.tensor(strains)
    if strains.shape != (6,):
        raise ValueError("Strains tensor must have shape (6,)")

    εxx, εyy, εzz, εyz, εxz, εxy = strains.unbind()

    matrix = torch.zeros((6, 5), dtype=strains.dtype, device=strains.device)
    # σxx: C11, C12, C13
    matrix[0, 0] = εxx
    matrix[0, 1] = εyy
    matrix[0, 2] = εzz
    # σyy: C11, C12, C13 (a <-> b swapped)
    matrix[1, 0] = εyy
    matrix[1, 1] = εxx
    matrix[1, 2] = εzz
    # σzz: C13, C33
    matrix[2, 2] = εxx + εyy
    matrix[2, 3] = εzz
    # σyz, σxz: C44
    matrix[3, 4] = 2 * εyz
    matrix[4, 4] = 2 * εxz
    # σxy: C66 = (C11 - C12) / 2 expressed through C11 and C12
    matrix[5, 0] = εxy
    matrix[5, 1] = -εxy
    return matrix
[docs]
def monoclinic_symmetry(strains: torch.Tensor) -> torch.Tensor:
    """Build the least-squares equation matrix for monoclinic symmetry.

    Uses the unique-axis-b (y) setting with 13 independent constants. Besides
    the orthorhombic set, the non-zero couplings are C15, C25, C35 and C46.

    Args:
        strains: Strain vector of shape (6,) ordered
            [εxx, εyy, εzz, εyz, εxz, εxy]; non-tensor input is converted
            with ``torch.tensor``.

    Returns:
        torch.Tensor: Matrix of shape (6, 13) on the device/dtype of
        ``strains``; columns are the coefficients of
        [C11, C12, C13, C15, C22, C23, C25, C33, C35, C44, C46, C55, C66].

    Raises:
        ValueError: If ``strains`` does not have shape (6,).

    Notes:
        Encoded relations:
            σxx = C11 εxx + C12 εyy + C13 εzz + 2 C15 εxz
            σyy = C12 εxx + C22 εyy + C23 εzz + 2 C25 εxz
            σzz = C13 εxx + C23 εyy + C33 εzz + 2 C35 εxz
            σyz = 2 C44 εyz + 2 C46 εxy
            σxz = C15 εxx + C25 εyy + C35 εzz + 2 C55 εxz
            σxy = 2 C46 εyz + 2 C66 εxy
    """
    strains = strains if isinstance(strains, torch.Tensor) else torch.tensor(strains)
    if strains.shape != (6,):
        raise ValueError("Strains tensor must have shape (6,)")

    e_xx, e_yy, e_zz, e_yz, e_xz, e_xy = strains.unbind()

    # Column indices: C11=0 C12=1 C13=2 C15=3 C22=4 C23=5 C25=6
    #                 C33=7 C35=8 C44=9 C46=10 C55=11 C66=12
    entries = (
        (0, 0, e_xx),
        (0, 1, e_yy),
        (0, 2, e_zz),
        (0, 3, 2 * e_xz),
        (1, 1, e_xx),
        (1, 4, e_yy),
        (1, 5, e_zz),
        (1, 6, 2 * e_xz),
        (2, 2, e_xx),
        (2, 5, e_yy),
        (2, 7, e_zz),
        (2, 8, 2 * e_xz),
        (3, 9, 2 * e_yz),
        (3, 10, 2 * e_xy),
        (4, 3, e_xx),
        (4, 6, e_yy),
        (4, 8, e_zz),
        (4, 11, 2 * e_xz),
        (5, 10, 2 * e_yz),
        (5, 12, 2 * e_xy),
    )
    matrix = torch.zeros((6, 13), dtype=strains.dtype, device=strains.device)
    for row, col, value in entries:
        matrix[row, col] = value
    return matrix
[docs]
def triclinic_symmetry(strains: torch.Tensor) -> torch.Tensor:
    """Build the least-squares equation matrix for triclinic symmetry.

    This is the general case: all 21 upper-triangle constants Cij (i <= j) of
    the symmetric 6x6 Voigt stiffness matrix are independent.

    Args:
        strains: Strain vector of shape (6,) ordered
            [εxx, εyy, εzz, εyz, εxz, εxy]; non-tensor input is converted
            with ``torch.tensor``.

    Returns:
        torch.Tensor: Matrix of shape (6, 21) on the device/dtype of
        ``strains``. Columns follow the row-major upper triangle:
            [C11, C12, C13, C14, C15, C16,
             C22, C23, C24, C25, C26,
             C33, C34, C35, C36,
             C44, C45, C46,
             C55, C56,
             C66]

    Raises:
        ValueError: If ``strains`` does not have shape (6,).

    Notes:
        With v = [εxx, εyy, εzz, 2εyz, 2εxz, 2εxy], the stress is
        σᵢ = Σⱼ Cᵢⱼ vⱼ. Constant C_ij (i <= j) therefore multiplies vⱼ in row i
        and, when i != j, vᵢ in row j.
    """
    strains = strains if isinstance(strains, torch.Tensor) else torch.tensor(strains)
    if strains.shape != (6,):
        raise ValueError("Strains tensor must have shape (6,)")

    e_xx, e_yy, e_zz, e_yz, e_xz, e_xy = strains.unbind()
    voigt = (e_xx, e_yy, e_zz, 2 * e_yz, 2 * e_xz, 2 * e_xy)

    matrix = torch.zeros((6, 21), dtype=strains.dtype, device=strains.device)
    column = 0
    for i in range(6):
        for j in range(i, 6):
            matrix[i, column] = voigt[j]
            if j != i:
                # The symmetric partner C_ji contributes to stress row j.
                matrix[j, column] = voigt[i]
            column += 1
    return matrix
[docs]
def get_strain(
    deformed_state: SimState, reference_state: SimState | None = None
) -> torch.Tensor:
    """Calculate strain tensor in Voigt notation.

    Computes the small-strain tensor of ``deformed_state`` relative to a
    reference (undeformed) cell and returns it as a 6-component vector.

    Args:
        deformed_state: SimState containing the deformed configuration.
        reference_state: Optional reference (undeformed) state. If None,
            ``deformed_state`` is its own reference and the strain is zero.

    Returns:
        torch.Tensor: 6-component strain vector [εxx, εyy, εzz, εyz, εxz, εxy]
        on the device/dtype of ``deformed_state.positions``. Shear entries are
        tensor (not engineering) shear strains.

    Raises:
        TypeError: If ``deformed_state`` is not a SimState.

    Notes:
        The strain is ε = (u + uᵀ)/2 with u = M⁻¹ ΔM, where M is the
        reference row-vector cell and ΔM the cell difference. Because ε is
        symmetric, the Voigt entries are read from its lower triangle:
        εyz = ε[2,1], εxz = ε[2,0], εxy = ε[1,0].
        Unlike the previous ``torch.tensor([...])`` construction, the result
        is built with ``torch.stack`` and so stays attached to the autograd
        graph of the cells.
    """
    # Validate before reading any attribute so callers get the documented
    # TypeError instead of an AttributeError.
    if not isinstance(deformed_state, SimState):
        raise TypeError("deformed_state must be an SimState")
    dtype = deformed_state.positions.dtype
    device = deformed_state.positions.device

    if reference_state is None:
        reference_state = deformed_state

    # squeeze() drops a leading batch dimension; a single system is assumed
    # (TODO confirm for batched states).
    deformed_cell = deformed_state.row_vector_cell.squeeze()
    reference_cell = reference_state.row_vector_cell.squeeze()

    # Displacement gradient u = M^(-1) ΔM, computed with solve() rather than
    # an explicit inverse for better numerical accuracy.
    u = torch.linalg.solve(reference_cell, deformed_cell - reference_cell)

    # Symmetric (small) strain tensor.
    strain = (u + u.transpose(-2, -1)) / 2

    return torch.stack(
        [
            strain[0, 0],  # εxx
            strain[1, 1],  # εyy
            strain[2, 2],  # εzz
            strain[2, 1],  # εyz
            strain[2, 0],  # εxz
            strain[1, 0],  # εxy
        ]
    ).to(device=device, dtype=dtype)
[docs]
def voigt_6_to_full_3x3_stress(stress_voigt: torch.Tensor) -> torch.Tensor:
    """Expand Voigt stress vectors into symmetric 3x3 stress matrices.

    Args:
        stress_voigt: Tensor of shape (..., 6) ordered
            [σxx, σyy, σzz, σyz, σxz, σxy]; leading dimensions are kept.

    Returns:
        torch.Tensor: Tensor of shape (..., 3, 3) with the same dtype and
        device as ``stress_voigt``. Each shear component is written to both
        symmetric off-diagonal positions.
    """
    # (row, column) of each Voigt component in the 3x3 matrix.
    voigt_positions = ((0, 0), (1, 1), (2, 2), (1, 2), (0, 2), (0, 1))

    full = stress_voigt.new_zeros((*stress_voigt.shape[:-1], 3, 3))
    for k, (i, j) in enumerate(voigt_positions):
        component = stress_voigt[..., k]
        full[..., i, j] = component
        full[..., j, i] = component
    return full
[docs]
def full_3x3_to_voigt_6_stress(stress: torch.Tensor) -> torch.Tensor:
    """Collapse 3x3 stress matrices into 6-component Voigt vectors.

    The input is first symmetrized as (σ + σᵀ)/2, so any antisymmetric part is
    averaged away rather than rejected.

    Args:
        stress: Tensor of shape (..., 3, 3); leading dimensions are kept.

    Returns:
        torch.Tensor: Tensor of shape (..., 6) ordered
        [σxx, σyy, σzz, σyz, σxz, σxy], with the dtype and device of
        ``stress``.
    """
    symmetric = (stress + stress.transpose(-2, -1)) / 2
    # Voigt order xx, yy, zz, yz, xz, xy read from the lower triangle.
    row_idx = [0, 1, 2, 2, 2, 1]
    col_idx = [0, 1, 2, 1, 0, 0]
    return symmetric[..., row_idx, col_idx].to(device=stress.device, dtype=stress.dtype)
[docs]
def get_elastic_coeffs(
    state: SimState,
    deformed_states: list[SimState],
    stresses: torch.Tensor,
    base_pressure: torch.Tensor,
    bravais_type: BravaisType,
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, int, torch.Tensor]]:
    """Calculate elastic constants from stress-strain relationships.

    Builds one symmetry-reduced linear system per deformed state, stacks them,
    and solves for the independent constants by least squares.

    Args:
        state: SimState containing the reference structure.
        deformed_states: List of deformed SimStates, one per row of
            ``stresses``.
        stresses: Tensor of shape (n_states, 6) of Voigt stresses
            [σxx, σyy, σzz, σyz, σxz, σxy] for each deformed state.
        base_pressure: Reference pressure of the base state (tensor or
            float).
        bravais_type: Crystal system (BravaisType enum).

    Returns:
        tuple containing:
            - torch.Tensor: Cij, the independent elastic constants in the
              column order of the symmetry's equation-matrix builder.
            - tuple containing:
                - torch.Tensor: Bij, the fitted stress-strain (Birch)
                  coefficients.
                - torch.Tensor: Residuals from the least-squares fit.
                - int: Rank of the equation matrix, as returned by
                  ``torch.linalg.lstsq``.
                - torch.Tensor: Singular values.

    Raises:
        KeyError: If ``bravais_type`` has no equation-matrix builder.

    Notes:
        The constants are obtained as ``Cij = Bij - p * s_ij`` with
        ``p = base_pressure``. The sign factor ``s_ij`` is:
            - ``-1`` for diagonal constants (C11 ... C66);
            - ``+1`` for the normal-normal couplings C12, C13, C23;
            - ``0`` for every other (normal-shear or shear-shear) coupling.
        This is the hydrostatic form of Wallace's relation
        ``B_ijkl = C_ijkl - P (δ_ik δ_jl + δ_il δ_jk - δ_ij δ_kl)``. It
        reproduces the cubic case ``C11 = B11 + P, C12 = B12 - P,
        C44 = B44 + P``.

        The sign factors are derived from the column layout of each builder,
        so they stay aligned with the fitted coefficients. Previously,
        hard-coded vectors were misaligned for hexagonal, orthorhombic,
        tetragonal, trigonal and monoclinic symmetry. This only matters when
        ``base_pressure`` is non-zero.

        Stress and strain are related by ``σᵢ = Σⱼ Cᵢⱼ εⱼ``.
    """
    # For each lattice: the equation-matrix builder and the 1-based Voigt
    # index pairs (i, j) of its independent constants, listed in exactly the
    # column order that builder produces.
    symmetry_table: dict[
        BravaisType,
        tuple[Callable[[torch.Tensor], torch.Tensor], list[tuple[int, int]]],
    ] = {
        BravaisType.CUBIC: (regular_symmetry, [(1, 1), (1, 2), (4, 4)]),
        BravaisType.HEXAGONAL: (
            hexagonal_symmetry,
            [(1, 1), (1, 2), (1, 3), (3, 3), (4, 4)],
        ),
        BravaisType.TRIGONAL: (
            trigonal_symmetry,
            [(1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (3, 3), (4, 4)],
        ),
        BravaisType.TETRAGONAL: (
            tetragonal_symmetry,
            [(1, 1), (1, 2), (1, 3), (1, 6), (3, 3), (4, 4), (6, 6)],
        ),
        BravaisType.ORTHORHOMBIC: (
            orthorhombic_symmetry,
            [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3), (4, 4), (5, 5), (6, 6)],
        ),
        BravaisType.MONOCLINIC: (
            monoclinic_symmetry,
            [
                (1, 1), (1, 2), (1, 3), (1, 5), (2, 2), (2, 3), (2, 5),
                (3, 3), (3, 5), (4, 4), (4, 6), (5, 5), (6, 6),
            ],
        ),
        BravaisType.TRICLINIC: (
            triclinic_symmetry,
            [(i, j) for i in range(1, 7) for j in range(i, 7)],
        ),
    }
    symmetry_handler, voigt_pairs = symmetry_table[bravais_type]

    # Strain of every deformed state relative to the reference structure.
    strains = [
        get_strain(deformed, reference_state=state) for deformed in deformed_states
    ]

    # Remove ambient pressure from the normal stresses.
    # NOTE(review): calculate_elastic_tensor passes
    # base_pressure = -trace(stress) / 3. If stresses are tension-positive, the
    # reference normal stress is -base_pressure, and subtracting base_pressure
    # doubles the offset rather than removing it. The fit has no intercept, so
    # a constant offset only drops out when sum_k M(ε_k)^T is zero (e.g. a
    # strain set symmetric about zero). Confirm the sign convention.
    p_correction = torch.tensor(
        [base_pressure] * 3 + [0] * 3, device=stresses.device, dtype=stresses.dtype
    )
    corrected_stresses = stresses - p_correction

    # Stack the per-state (6, n_coeffs) systems into one tall system.
    eq_matrix = torch.stack([symmetry_handler(strain) for strain in strains])
    eq_matrix = eq_matrix.reshape(-1, eq_matrix.shape[-1])
    stress_vector = corrected_stresses.reshape(-1)

    Bij, residuals, rank, singular_values = torch.linalg.lstsq(eq_matrix, stress_vector)

    # Sign factor s_ij per fitted column (see Notes). Pairs satisfy i <= j, so
    # an off-diagonal pair is normal-normal exactly when j <= 3.
    signs = torch.tensor(
        [-1.0 if i == j else (1.0 if j <= 3 else 0.0) for i, j in voigt_pairs],
        dtype=Bij.dtype,
        device=Bij.device,
    )
    p = torch.as_tensor(base_pressure, dtype=Bij.dtype, device=Bij.device)
    Cij = Bij - p * signs
    return Cij, (Bij, residuals, rank, singular_values)
[docs]
def get_elastic_tensor_from_coeffs(  # noqa: C901, PLR0915
    Cij: torch.Tensor,
    bravais_type: BravaisType,
) -> torch.Tensor:
    """Convert the symmetry-reduced elastic constants to a full 6x6 tensor.

    Args:
        Cij: 1-D tensor of independent elastic constants, in the column order
            of the symmetry's equation-matrix builder (listed below).
        bravais_type: Crystal system determining the symmetry rules.

    Returns:
        torch.Tensor: Symmetric 6x6 elastic tensor in Voigt notation with the
        dtype and device of ``Cij``. It is built from ``Cij`` without leaving
        the autograd graph.

    Raises:
        ValueError: If ``bravais_type`` is not supported, or if ``Cij`` does
            not contain the number of constants that symmetry requires.

    Notes:
        Voigt indices: 1 = xx, 2 = yy, 3 = zz, 4 = yz, 5 = xz, 6 = xy.
        Expected constants, in order:
            - Cubic: 3 (C11, C12, C44)
            - Hexagonal: 5 (C11, C12, C13, C33, C44); C66 = (C11 - C12) / 2
            - Trigonal: 7 (C11, C12, C13, C14, C15, C33, C44);
              C66 = (C11 - C12) / 2
            - Tetragonal: 7 (C11, C12, C13, C16, C33, C44, C66)
            - Orthorhombic: 9 (C11, C12, C13, C22, C23, C33, C44, C55, C66)
            - Monoclinic: 13 (C11, C12, C13, C15, C22, C23, C25, C33, C35,
              C44, C46, C55, C66)
            - Triclinic: 21, row-major upper triangle (C11 ... C16, C22 ...
              C26, ..., C66)
    """
    # Human-readable label and required number of constants per lattice.
    specs = {
        BravaisType.CUBIC: ("Cubic", 3),
        BravaisType.HEXAGONAL: ("Hexagonal", 5),
        BravaisType.TRIGONAL: ("Trigonal", 7),
        BravaisType.TETRAGONAL: ("Tetragonal", 7),
        BravaisType.ORTHORHOMBIC: ("Orthorhombic", 9),
        BravaisType.MONOCLINIC: ("Monoclinic", 13),
        BravaisType.TRICLINIC: ("Triclinic", 21),
    }
    if bravais_type not in specs:
        # Previously an unsupported type silently produced an all-zero tensor.
        raise ValueError(f"Unsupported Bravais type: {bravais_type}")
    label, n_expected = specs[bravais_type]
    if len(Cij) != n_expected:
        raise ValueError(
            f"{label} symmetry requires {n_expected} independent constants, "
            f"but got {len(Cij)}"
        )

    C = torch.zeros((6, 6), dtype=Cij.dtype, device=Cij.device)

    def put(i: int, j: int, value: torch.Tensor) -> None:
        # Write value at (i, j) and its mirror (0-based Voigt indices).
        C[i, j] = value
        C[j, i] = value

    def put_diagonal(values: tuple[torch.Tensor, ...]) -> None:
        # Fill C11 ... C66 element-wise (keeps device, dtype and autograd).
        for k, value in enumerate(values):
            C[k, k] = value

    if bravais_type == BravaisType.TRICLINIC:
        upper_pairs = [(i, j) for i in range(6) for j in range(i, 6)]
        for value, (i, j) in zip(Cij, upper_pairs, strict=True):
            put(i, j, value)
    elif bravais_type == BravaisType.CUBIC:
        C11, C12, C44 = Cij
        put_diagonal((C11, C11, C11, C44, C44, C44))
        for i, j in ((0, 1), (0, 2), (1, 2)):
            put(i, j, C12)
    elif bravais_type == BravaisType.HEXAGONAL:
        C11, C12, C13, C33, C44 = Cij
        put_diagonal((C11, C11, C33, C44, C44, (C11 - C12) / 2))
        put(0, 1, C12)
        put(0, 2, C13)
        put(1, 2, C13)
    elif bravais_type == BravaisType.TRIGONAL:
        C11, C12, C13, C14, C15, C33, C44 = Cij
        put_diagonal((C11, C11, C33, C44, C44, (C11 - C12) / 2))
        put(0, 1, C12)
        put(0, 2, C13)
        put(1, 2, C13)
        put(0, 3, C14)
        put(0, 4, C15)
        put(1, 3, -C14)
        put(1, 4, -C15)
        put(3, 5, -C15)
        put(4, 5, C14)
    elif bravais_type == BravaisType.TETRAGONAL:
        C11, C12, C13, C16, C33, C44, C66 = Cij
        put_diagonal((C11, C11, C33, C44, C44, C66))
        put(0, 1, C12)
        put(0, 2, C13)
        put(1, 2, C13)
        put(0, 5, C16)
        put(1, 5, -C16)
    elif bravais_type == BravaisType.ORTHORHOMBIC:
        C11, C12, C13, C22, C23, C33, C44, C55, C66 = Cij
        put_diagonal((C11, C22, C33, C44, C55, C66))
        put(0, 1, C12)
        put(0, 2, C13)
        put(1, 2, C23)
    elif bravais_type == BravaisType.MONOCLINIC:
        C11, C12, C13, C15, C22, C23, C25, C33, C35, C44, C46, C55, C66 = Cij
        put_diagonal((C11, C22, C33, C44, C55, C66))
        put(0, 1, C12)
        put(0, 2, C13)
        put(0, 4, C15)
        put(1, 2, C23)
        put(1, 4, C25)
        put(2, 4, C35)
        put(3, 5, C46)
    return C
[docs]
def calculate_elastic_tensor(
    model: torch.nn.Module,
    *,
    state: SimState,
    bravais_type: BravaisType = BravaisType.TRICLINIC,
    max_strain_normal: float = 0.01,
    max_strain_shear: float = 0.06,
    n_deform: int = 5,
) -> torch.Tensor:
    """Compute the full 6x6 elastic tensor of ``state`` with ``model``.

    Pipeline:
        1. generate symmetry-appropriate deformations of ``state``;
        2. evaluate the model stress on each deformation;
        3. fit the independent constants by least squares;
        4. expand them to the full tensor.

    Args:
        model: Callable model whose output mapping has a ``"stress"`` entry
            (a 3x3 stress, possibly with a leading singleton dimension).
        state: SimState of the reference structure; its ``stress`` attribute
            is used for the reference pressure.
        bravais_type: Crystal system used to reduce the constants.
        max_strain_normal: Largest normal strain applied.
        max_strain_shear: Largest shear strain applied.
        n_deform: Number of deformations per strain component.

    Returns:
        torch.Tensor: 6x6 elastic tensor in Voigt notation.
    """
    device, dtype = state.positions.device, state.positions.dtype

    deformed_states = get_elementary_deformations(
        state,
        n_deform=n_deform,
        max_strain_normal=max_strain_normal,
        max_strain_shear=max_strain_shear,
        bravais_type=bravais_type,
    )

    # Pressure of the reference state: minus the mean normal stress.
    reference_pressure = -torch.trace(state.stress.squeeze()) / 3

    voigt_stresses = torch.zeros((len(deformed_states), 6), device=device, dtype=dtype)
    for row, deformed in enumerate(deformed_states):
        output = model(deformed)
        voigt_stresses[row] = full_3x3_to_voigt_6_stress(output["stress"].squeeze())

    independent_coeffs, _fit_details = get_elastic_coeffs(
        state, deformed_states, voigt_stresses, reference_pressure, bravais_type
    )
    return get_elastic_tensor_from_coeffs(independent_coeffs, bravais_type)
[docs]
def calculate_elastic_moduli(C: torch.Tensor) -> tuple[float, float, float, float]:
    """Derive Voigt-Reuss-Hill averaged isotropic moduli from a 6x6 tensor.

    The Voigt bound averages the stiffness ``C``; the Reuss bound averages
    the compliance ``S = C⁻¹``; the Hill value is their arithmetic mean.

    Args:
        C: 6x6 elastic tensor in Voigt notation. Non-tensor input is
            converted with ``torch.tensor``.

    Returns:
        tuple of Python floats, in order:
            - bulk modulus K_VRH (same units as ``C``)
            - shear modulus G_VRH (same units as ``C``)
            - Poisson's ratio (3K - 2G) / (6K + 2G), dimensionless
            - Pugh's ratio K/G, dimensionless

    Raises:
        torch.linalg.LinAlgError: If ``C`` is singular (raised by
            ``torch.linalg.inv``).
    """
    if not isinstance(C, torch.Tensor):
        C = torch.tensor(C)

    S = torch.linalg.inv(C)  # compliance tensor

    def _normal_diag(M: torch.Tensor) -> torch.Tensor:
        # M11 + M22 + M33
        return M[0, 0] + M[1, 1] + M[2, 2]

    def _normal_off(M: torch.Tensor) -> torch.Tensor:
        # M12 + M23 + M31
        return M[0, 1] + M[1, 2] + M[2, 0]

    def _shear_diag(M: torch.Tensor) -> torch.Tensor:
        # M44 + M55 + M66
        return M[3, 3] + M[4, 4] + M[5, 5]

    c_diag, c_off, c_shear = _normal_diag(C), _normal_off(C), _shear_diag(C)
    s_diag, s_off, s_shear = _normal_diag(S), _normal_off(S), _shear_diag(S)

    # Voigt (upper) and Reuss (lower) bounds.
    bulk_voigt = (1 / 9) * (c_diag + 2 * c_off)
    shear_voigt = (1 / 15) * (c_diag - c_off + 3 * c_shear)
    bulk_reuss = 1 / (s_diag + 2 * s_off)
    shear_reuss = 15 / (4 * s_diag - 4 * s_off + 3 * s_shear)

    # Hill averages and derived ratios.
    bulk = (bulk_voigt + bulk_reuss) / 2
    shear = (shear_voigt + shear_reuss) / 2
    poisson = (3 * bulk - 2 * shear) / (6 * bulk + 2 * shear)
    pugh = bulk / shear

    return tuple(float(value.item()) for value in (bulk, shear, poisson, pugh))