Source code for torch_sim.quantities

"""Functions for computing physical quantities."""

import torch

from torch_sim.state import SimState
from torch_sim.units import MetalUnits


# @torch.jit.script
def count_dof(tensor: torch.Tensor) -> int:
    """Return the number of degrees of freedom represented by a tensor.

    Every element of the tensor is counted as one degree of freedom.

    Args:
        tensor: Tensor whose elements are counted

    Returns:
        Total number of elements in ``tensor``
    """
    n_elements = tensor.numel()
    return n_elements
# @torch.jit.script
[docs] def calc_kT( # noqa: N802 momenta: torch.Tensor, masses: torch.Tensor, velocities: torch.Tensor | None = None, batch: torch.Tensor | None = None, ) -> torch.Tensor: """Calculate temperature in energy units from momenta/velocities and masses. Args: momenta (torch.Tensor): Particle momenta, shape (n_particles, n_dim) masses (torch.Tensor): Particle masses, shape (n_particles,) velocities (torch.Tensor | None): Particle velocities, shape (n_particles, n_dim) batch (torch.Tensor | None): Optional tensor indicating batch membership of each particle Returns: torch.Tensor: Scalar temperature value """ if momenta is not None and velocities is not None: raise ValueError("Must pass either momenta or velocities, not both") if momenta is None and velocities is None: raise ValueError("Must pass either momenta or velocities") if momenta is None: # If velocity provided, calculate mv^2 squared_term = (velocities**2) * masses.unsqueeze(-1) else: # If momentum provided, calculate v^2 = p^2/m^2 squared_term = (momenta**2) / masses.unsqueeze(-1) if batch is None: # Count total degrees of freedom dof = count_dof(squared_term) return torch.sum(squared_term) / dof # Sum squared terms for each batch flattened_squared = torch.sum(squared_term, dim=-1) # Count degrees of freedom per batch batch_sizes = torch.bincount(batch) dof_per_batch = batch_sizes * squared_term.shape[-1] # multiply by n_dimensions # Calculate temperature per batch batch_sums = torch.segment_reduce( flattened_squared, reduce="sum", lengths=batch_sizes ) return batch_sums / dof_per_batch
def calc_temperature(
    momenta: torch.Tensor,
    masses: torch.Tensor,
    velocities: torch.Tensor | None = None,
    batch: torch.Tensor | None = None,
    units: object = MetalUnits.temperature,
) -> torch.Tensor:
    """Calculate temperature from momenta/velocities and masses.

    Computes the temperature in energy units with ``calc_kT`` and divides the
    result by ``units``.

    Args:
        momenta (torch.Tensor): Particle momenta, shape (n_particles, n_dim)
        masses (torch.Tensor): Particle masses, shape (n_particles,)
        velocities (torch.Tensor | None): Particle velocities, shape
            (n_particles, n_dim)
        batch (torch.Tensor | None): Optional tensor indicating batch membership
            of each particle
        units (object): Units to return the temperature in; the energy-unit
            temperature is divided by this value

    Returns:
        torch.Tensor: Temperature value in specified units
    """
    kinetic_temperature = calc_kT(
        momenta=momenta, masses=masses, velocities=velocities, batch=batch
    )
    return kinetic_temperature / units
# @torch.jit.script
[docs] def calc_kinetic_energy( momenta: torch.Tensor, masses: torch.Tensor, velocities: torch.Tensor | None = None, batch: torch.Tensor | None = None, ) -> torch.Tensor: """Computes the kinetic energy of a system. Args: momenta (torch.Tensor): Particle momenta, shape (n_particles, n_dim) masses (torch.Tensor): Particle masses, shape (n_particles,) velocities (torch.Tensor | None): Particle velocities, shape (n_particles, n_dim) batch (torch.Tensor | None): Optional tensor indicating batch membership of each particle Returns: If batch is None: Scalar tensor containing the total kinetic energy If batch is provided: Tensor of kinetic energies per batch """ if momenta is not None and velocities is not None: raise ValueError("Must pass either momenta or velocities, not both") if momenta is None and velocities is None: raise ValueError("Must pass either momenta or velocities") if momenta is None: # Using velocities squared_term = (velocities**2) * masses.unsqueeze(-1) else: # Using momentum squared_term = (momenta**2) / masses.unsqueeze(-1) if batch is None: return 0.5 * torch.sum(squared_term) flattened_squared = torch.sum(squared_term, dim=-1) return 0.5 * torch.segment_reduce( flattened_squared, reduce="sum", lengths=torch.bincount(batch) )
def batchwise_max_force(state: SimState) -> torch.Tensor:
    """Compute the maximum force per batch.

    The norm of each row of ``state.forces`` is taken along dim 1, and the
    largest norm among the rows assigned to each batch index is returned.
    The reduction starts from zeros, so a batch index with no rows yields 0.

    Args:
        state (SimState): State to compute the maximum force per batch for.

    Returns:
        torch.Tensor: Maximum forces per batch
    """
    # One force magnitude per row of the forces tensor
    force_norms = state.forces.norm(dim=1)

    # Zero-initialised accumulator with one slot per batch
    per_batch_init = torch.zeros(
        state.n_batches, device=state.device, dtype=state.dtype
    )

    # Keep the largest magnitude seen for each batch index
    return per_batch_init.scatter_reduce(
        dim=0, index=state.batch, src=force_norms, reduce="amax"
    )