# Source code for torch_sim.models.particle_life

"""Particle life model for computing forces between particles."""

import torch

import torch_sim as ts
from torch_sim import transforms
from torch_sim.models.interface import ModelInterface
from torch_sim.neighbors import vesin_nl_ts
from torch_sim.typing import StateDict


# Default inner (beta) and outer (sigma) interaction radii, used as the default
# arguments of the pair-force functions in this module. Both are 0-dim tensors
# created with torch's default floating dtype.
DEFAULT_BETA, DEFAULT_SIGMA = torch.tensor(0.3), torch.tensor(1.0)


def asymmetric_particle_pair_force(
    dr: torch.Tensor,
    A: torch.Tensor,
    beta: torch.Tensor = DEFAULT_BETA,
    sigma: torch.Tensor = DEFAULT_SIGMA,
) -> torch.Tensor:
    """Asymmetric interaction between particles.

    The scalar pair force is piecewise linear in the distance ``dr``:

    * ``dr < beta``: ramp ``dr / beta - 1``, going from -1 at ``dr = 0`` to 0
      at ``dr = beta``.
    * ``beta < dr < sigma``: triangular bump of height ``A`` that is zero at
      both ``beta`` and ``sigma`` and peaks at ``(beta + sigma) / 2``.
    * otherwise (including exactly ``dr == beta`` or ``dr >= sigma``): 0.

    For ``sigma == 1`` the middle branch is exactly
    ``A * (1 - |2 * dr - 1 - beta| / (1 - beta))``.

    Args:
        dr: A tensor of shape [n, m] of pairwise distances between particles.
        A: Interaction scale. Either a float scalar or a tensor of shape [n, m].
        beta: Inner radius of the interaction. Either a float scalar or tensor of
            shape [n, m].
        sigma: Outer radius of the interaction. Either a float scalar or tensor of
            shape [n, m]. Expected to be larger than ``beta``.

    Returns:
        torch.Tensor: Scalar pair forces with shape [n, m].
    """
    inner_mask = dr < beta
    outer_mask = (dr < sigma) & (dr > beta)

    def inner_force_fn(dr: torch.Tensor) -> torch.Tensor:
        return dr / beta - 1

    def intermediate_force_fn(dr: torch.Tensor) -> torch.Tensor:
        # Triangle over (beta, sigma). Using sigma instead of a hard-coded 1
        # keeps the force continuous and zero at the outer radius.
        return A * (1 - torch.abs(2 * dr - sigma - beta) / (sigma - beta))

    return torch.where(inner_mask, inner_force_fn(dr), 0) + torch.where(
        outer_mask,
        intermediate_force_fn(dr),
        0,
    )
def asymmetric_particle_pair_force_jit(
    dr: torch.Tensor,
    A: torch.Tensor,
    beta: torch.Tensor = DEFAULT_BETA,
    sigma: torch.Tensor = DEFAULT_SIGMA,
) -> torch.Tensor:
    """Asymmetric interaction between particles.

    Same kernel as ``asymmetric_particle_pair_force``, written without nested
    helper functions (presumably to keep it friendly to scripting/compilation;
    confirm against callers).

    The scalar pair force is:

    * ``dr < beta``: ``dr / beta - 1`` (from -1 at 0 up to 0 at ``beta``).
    * ``beta < dr < sigma``: triangular bump of height ``A`` that is zero at
      ``beta`` and ``sigma`` and peaks at ``(beta + sigma) / 2``.
    * otherwise: 0.

    Args:
        dr: A tensor of shape [n, m] of pairwise distances between particles.
        A: Interaction scale. Either a float scalar or a tensor of shape [n, m].
        beta: Inner radius of the interaction. Either a float scalar or tensor of
            shape [n, m].
        sigma: Outer radius of the interaction. Either a float scalar or tensor of
            shape [n, m]. Expected to be larger than ``beta``.

    Returns:
        torch.Tensor: Scalar pair forces with shape [n, m].
    """
    inner_mask = dr < beta
    outer_mask = (dr < sigma) & (dr > beta)

    # Inner ramp: -1 at contact, 0 at the inner radius.
    inner_forces = torch.where(inner_mask, dr / beta - 1, torch.zeros_like(dr))

    # Outer triangle over (beta, sigma). Normalizing by (sigma - beta) rather than
    # (1 - beta) keeps the force zero at the outer radius for any sigma.
    outer_forces = torch.where(
        outer_mask,
        A * (1 - torch.abs(2 * dr - sigma - beta) / (sigma - beta)),
        torch.zeros_like(dr),
    )
    return inner_forces + outer_forces
class ParticleLifeModel(torch.nn.Module, ModelInterface):
    """Calculator for asymmetric particle interaction.

    This model implements an asymmetric interaction between particles based on
    distance-dependent forces. The interaction is defined by three parameters:
    sigma (outer radius), epsilon (interaction scale, passed as ``A`` to the pair
    force) and beta (inner radius).

    Only forces are computed. The reported energy is a zero placeholder, and the
    ``compute_stress`` / ``per_atom_*`` flags are stored but not acted on here.
    """

    def __init__(
        self,
        sigma: float = 1.0,
        epsilon: float = 1.0,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
        *,  # Force keyword-only arguments
        compute_forces: bool = False,
        compute_stress: bool = False,
        per_atom_energies: bool = False,
        per_atom_stresses: bool = False,
        use_neighbor_list: bool = True,
        cutoff: float | None = None,
        beta: float = 0.3,
    ) -> None:
        """Initialize the calculator.

        Args:
            sigma: Outer radius of the interaction.
            epsilon: Interaction scale (the ``A`` argument of the pair force).
            device: Device for parameter tensors; defaults to CPU.
            dtype: Floating dtype for parameter tensors.
            compute_forces: Stored flag; forces are returned regardless.
            compute_stress: Stored flag; stress is not computed by this model.
            per_atom_energies: Stored flag; not acted on by this model.
            per_atom_stresses: Stored flag; not acted on by this model.
            use_neighbor_list: Use a neighbor list instead of all pairs.
            cutoff: Neighbor cutoff distance; defaults to ``2.5 * sigma``.
            beta: Inner radius of the interaction (default matches the module
                default of 0.3).
        """
        super().__init__()
        self._device = device or torch.device("cpu")
        self._dtype = dtype
        self._compute_forces = compute_forces
        self._compute_stress = compute_stress
        self._per_atom_energies = per_atom_energies
        self._per_atom_stresses = per_atom_stresses
        self.use_neighbor_list = use_neighbor_list

        # Convert parameters to tensors
        self.sigma = torch.tensor(sigma, dtype=self.dtype, device=self.device)
        # `is None` so that an explicit cutoff of 0.0 is not silently replaced.
        self.cutoff = torch.tensor(
            cutoff if cutoff is not None else 2.5 * sigma,
            dtype=self.dtype,
            device=self.device,
        )
        self.epsilon = torch.tensor(epsilon, dtype=self.dtype, device=self.device)
        self.beta = torch.tensor(beta, dtype=self.dtype, device=self.device)

    def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]:
        """Compute energies and forces for a single unbatched system.

        Internal implementation that processes a single, non-batched simulation
        state. This method handles the core computations of pair interactions,
        neighbor lists, and property calculations.

        Args:
            state: Single, non-batched simulation state containing atomic
                positions, cell vectors, and other system information.

        Returns:
            A dictionary with a 0-dim zero ``"energy"`` tensor and the
            ``"forces"`` tensor (same shape as ``state.positions``).
        """
        # Extract required data from input
        if isinstance(state, dict):
            # NOTE(review): masses is built with the shape of positions here;
            # confirm that matches what SimState expects.
            state = ts.SimState(**state, masses=torch.ones_like(state["positions"]))

        positions = state.positions
        cell = state.row_vector_cell
        pbc = state.pbc

        if cell.dim() == 3:  # Check if there is an extra batch dimension
            cell = cell.squeeze(0)  # Squeeze the first dimension

        if self.use_neighbor_list:
            # Get neighbor list
            mapping, shifts = vesin_nl_ts(
                positions=positions,
                cell=cell,
                pbc=pbc,
                cutoff=float(self.cutoff),
                sort_id=False,
            )
            # Get displacements using neighbor list
            dr_vec, distances = transforms.get_pair_displacements(
                positions=positions,
                cell=cell,
                pbc=pbc,
                pairs=mapping,
                shifts=shifts,
            )
        else:
            # Get all pairwise displacements
            dr_vec, distances = transforms.get_pair_displacements(
                positions=positions,
                cell=cell,
                pbc=pbc,
            )
            # Mask out self-interactions
            mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
            distances = distances.masked_fill(mask, float("inf"))
            # Apply cutoff
            mask = distances < self.cutoff
            # Valid pairs, stored as [j, i] (column index first) to match the
            # neighbor list pair-order convention.
            i, j = torch.where(mask)
            mapping = torch.stack([j, i])
            # Get valid displacements and distances
            dr_vec = dr_vec[mask]
            distances = distances[mask]

        # Zero out pair forces beyond the cutoff
        mask = distances < self.cutoff

        # No energy is defined for this interaction; return a 0-dim tensor so
        # that `forward` can stack it across systems.
        results: dict[str, torch.Tensor] = {"energy": positions.new_zeros(())}

        # Calculate scalar pair forces; epsilon is the interaction scale `A`.
        pair_forces = asymmetric_particle_pair_force_jit(
            distances, A=self.epsilon, beta=self.beta, sigma=self.sigma
        )
        pair_forces = torch.where(mask, pair_forces, torch.zeros_like(pair_forces))

        # Project forces along displacement vectors
        force_vectors = (pair_forces / distances)[:, None] * dr_vec

        # Accumulate per-atom forces: each pair's vector is subtracted on its
        # first index and added on its second index.
        forces = torch.zeros_like(state.positions)
        forces.index_add_(0, mapping[0], -force_vectors)
        forces.index_add_(0, mapping[1], force_vectors)
        results["forces"] = forces

        return results

    def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]:
        """Compute particle life energies and forces for a system.

        Main entry point for particle life calculations. It handles batched
        states by dispatching each batch to the unbatched implementation and
        combining the results.

        Args:
            state: Input state containing atomic positions, cell vectors, and
                other system information. Can be a SimState object or a
                dictionary with the same keys.

        Returns:
            dict[str, torch.Tensor]: Computed properties:
                - "energy": Zero placeholder with shape [n_batches]
                - "forces": Atomic forces with shape [n_atoms, 3]
                Keys "stress", "energies" and "stresses" are combined
                (stacked / concatenated) if the unbatched results contain them;
                this model's unbatched implementation does not produce them.

        Raises:
            ValueError: If batch cannot be inferred for multi-cell systems.
        """
        if isinstance(state, dict):
            state = ts.SimState(**state, masses=torch.ones_like(state["positions"]))

        if state.batch is None and state.cell.shape[0] > 1:
            raise ValueError("Batch can only be inferred for batch size 1.")

        outputs = [self.unbatched_forward(state[i]) for i in range(state.n_batches)]
        properties = outputs[0]

        # we always return tensors
        # per atom properties are returned as (atoms, ...) tensors
        # global properties are returned as shape (..., n) tensors
        results: dict[str, torch.Tensor] = {}
        for key in ("stress", "energy"):
            if key in properties:
                results[key] = torch.stack([out[key] for out in outputs])
        for key in ("forces", "energies", "stresses"):
            if key in properties:
                results[key] = torch.cat([out[key] for out in outputs], dim=0)

        return results