"""Functions for computing physical quantities."""
from typing import cast
import torch
from torch_sim.state import SimState
from torch_sim.units import MetalUnits
# @torch.jit.script
[docs]
def count_dof(tensor: torch.Tensor) -> int:
    """Return the number of degrees of freedom represented by ``tensor``.

    Every element of the tensor counts as one degree of freedom, so a tensor of
    shape (n_particles, n_dim) yields n_particles * n_dim.

    Args:
        tensor: Tensor whose elements are counted.

    Returns:
        Total number of elements in ``tensor``.
    """
    n_elements = tensor.nelement()
    return n_elements
# @torch.jit.script
[docs]
def calc_kT( # noqa: N802
*,
masses: torch.Tensor,
momenta: torch.Tensor | None = None,
velocities: torch.Tensor | None = None,
system_idx: torch.Tensor | None = None,
) -> torch.Tensor:
"""Calculate temperature in energy units from momenta/velocities and masses.
Args:
momenta (torch.Tensor): Particle momenta, shape (n_particles, n_dim)
masses (torch.Tensor): Particle masses, shape (n_particles,)
velocities (torch.Tensor | None): Particle velocities, shape (n_particles, n_dim)
system_idx (torch.Tensor | None): Optional tensor indicating system membership of
each particle
Returns:
torch.Tensor: Scalar temperature value
"""
if not ((momenta is not None) ^ (velocities is not None)):
raise ValueError("Must pass either one of momenta or velocities")
if momenta is None:
# If velocity provided, calculate mv^2
velocities = cast("torch.Tensor", velocities)
squared_term = (velocities**2) * masses.unsqueeze(-1)
else:
# If momentum provided, calculate v^2 = p^2/m^2
squared_term = (momenta**2) / masses.unsqueeze(-1)
if system_idx is None:
# Count total degrees of freedom
dof = count_dof(squared_term)
return torch.sum(squared_term) / dof
# Sum squared terms for each system
flattened_squared = torch.sum(squared_term, dim=-1)
# Count degrees of freedom per system
system_sizes = torch.bincount(system_idx)
dof_per_system = system_sizes * squared_term.shape[-1] # multiply by n_dimensions
# Calculate temperature per system
system_sums = torch.segment_reduce(
flattened_squared, reduce="sum", lengths=system_sizes
)
return system_sums / dof_per_system
[docs]
def calc_temperature(
    *,
    masses: torch.Tensor,
    momenta: torch.Tensor | None = None,
    velocities: torch.Tensor | None = None,
    system_idx: torch.Tensor | None = None,
    units: MetalUnits = MetalUnits.temperature,
) -> torch.Tensor:
    """Compute the temperature of one or more systems in the requested units.

    This delegates to ``calc_kT`` for the energy-unit temperature and then divides
    the result by ``units``.

    Args:
        masses (torch.Tensor): Particle masses, shape (n_particles,)
        momenta (torch.Tensor | None): Particle momenta, shape (n_particles, n_dim)
        velocities (torch.Tensor | None): Particle velocities, shape (n_particles, n_dim)
        system_idx (torch.Tensor | None): Optional tensor indicating system membership
            of each particle
        units (MetalUnits): Divisor converting kT into the desired temperature units

    Returns:
        torch.Tensor: Temperature (scalar, or per system when ``system_idx`` is given)
    """
    energy_temperature = calc_kT(
        masses=masses,
        momenta=momenta,
        velocities=velocities,
        system_idx=system_idx,
    )
    # NOTE(review): presumably `units` is k_B in the energy unit system, so this
    # yields kelvin by default — confirm against MetalUnits.
    return energy_temperature / units
# @torch.jit.script
[docs]
def calc_kinetic_energy(
*,
masses: torch.Tensor,
momenta: torch.Tensor | None = None,
velocities: torch.Tensor | None = None,
system_idx: torch.Tensor | None = None,
) -> torch.Tensor:
"""Computes the kinetic energy of a system.
Args:
momenta (torch.Tensor): Particle momenta, shape (n_particles, n_dim)
masses (torch.Tensor): Particle masses, shape (n_particles,)
velocities (torch.Tensor | None): Particle velocities, shape (n_particles, n_dim)
system_idx (torch.Tensor | None): Optional tensor indicating system membership of
each particle
Returns:
If system_idx is None: Scalar tensor containing the total kinetic energy
If system_idx is provided: Tensor of kinetic energies per system
"""
if not ((momenta is not None) ^ (velocities is not None)):
raise ValueError("Must pass either one of momenta or velocities")
if momenta is None: # Using velocities
squared_term = (velocities**2) * masses.unsqueeze(-1)
else: # Using momenta
squared_term = (momenta**2) / masses.unsqueeze(-1)
if system_idx is None:
return 0.5 * torch.sum(squared_term)
flattened_squared = torch.sum(squared_term, dim=-1)
return 0.5 * torch.segment_reduce(
flattened_squared, reduce="sum", lengths=torch.bincount(system_idx)
)
[docs]
def get_pressure(
    stress: torch.Tensor, kinetic_energy: torch.Tensor, volume: torch.Tensor, dim: int = 3
) -> torch.Tensor:
    """Compute the pressure from the stress tensor and the kinetic energy.

    The stress tensor is taken to be 1/volume * dU/de_ij. The returned pressure is

        P = (2 * kinetic_energy / volume - trace(stress)) / dim

    Args:
        stress: Stress tensor(s); the trace is taken over the last two dimensions,
            so leading batch dimensions are allowed.
        kinetic_energy: Kinetic energy, broadcastable against the trace.
        volume: Volume, broadcastable against ``kinetic_energy``.
        dim: Number of spatial dimensions.

    Returns:
        Pressure with the broadcast shape of the inputs (batch dims of ``stress``).
    """
    kinetic_term = 2 * kinetic_energy / volume
    stress_trace = torch.einsum("...ii", stress)
    return 1 / dim * (kinetic_term - stress_trace)
[docs]
def systemwise_max_force(state: SimState) -> torch.Tensor:
    """Return the largest per-atom force magnitude in each system.

    Args:
        state (SimState): State providing ``forces``, ``system_idx``, ``n_systems``,
            ``device`` and ``dtype``.

    Returns:
        torch.Tensor: Tensor of length ``state.n_systems`` with the maximum force
        norm of each system.
    """
    # Norm of each atom's force vector (row-wise)
    atom_force_norms = state.forces.norm(dim=1)
    result = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype)
    # include_self=True (the default) folds the zero initial value into the max;
    # systems that receive no atoms therefore report 0.
    return result.scatter_reduce(
        dim=0,
        index=state.system_idx,
        src=atom_force_norms,
        reduce="amax",
        include_self=True,
    )