# Source code for torch_sim.models.lennard_jones

"""Classical pairwise interatomic potential model.

This module implements the Lennard-Jones potential for molecular dynamics simulations.
It provides efficient calculation of energies, forces, and stresses based on the
classic 12-6 potential function. The implementation supports both full pairwise
calculations and neighbor list-based optimizations.

Example::

    # Create a Lennard-Jones model with default parameters
    model = LennardJonesModel(device=torch.device("cuda"))

    # Create a model with custom parameters
    model = LennardJonesModel(
        sigma=3.405,  # Angstroms
        epsilon=0.01032,  # eV
        cutoff=10.0,  # Angstroms
        compute_stress=True,
    )

    # Calculate properties for a simulation state
    output = model(sim_state)
    energy = output["energy"]
    forces = output["forces"]
"""

import torch

from torch_sim.models.interface import ModelInterface
from torch_sim.neighbors import vesin_nl_ts
from torch_sim.state import SimState
from torch_sim.transforms import get_pair_displacements
from torch_sim.typing import StateDict
from torch_sim.unbatched.models.lennard_jones import (
    lennard_jones_pair,
    lennard_jones_pair_force,
)


# Module-level default Lennard-Jones parameters in reduced units
# (sigma = epsilon = 1), stored as 0-dim tensors of torch's default dtype.
# NOTE(review): the visible LennardJonesModel uses plain-float defaults instead
# of these; other modules may import them — confirm before removing.
DEFAULT_SIGMA: torch.Tensor = torch.as_tensor(1.0)
DEFAULT_EPSILON: torch.Tensor = torch.as_tensor(1.0)


class LennardJonesModel(torch.nn.Module, ModelInterface):
    """Lennard-Jones potential energy and force calculator.

    Implements the Lennard-Jones 12-6 potential for molecular dynamics simulations.
    This model calculates pairwise interactions between atoms and supports either
    full pairwise calculation or neighbor list-based optimization for efficiency.

    Attributes:
        sigma (torch.Tensor): Length parameter controlling particle size/repulsion
            distance.
        epsilon (torch.Tensor): Energy parameter controlling interaction strength.
        cutoff (torch.Tensor): Distance cutoff for truncating potential calculation.
        device (torch.device): Device where calculations are performed.
        dtype (torch.dtype): Data type used for calculations.
        compute_forces (bool): Whether to compute atomic forces.
        compute_stress (bool): Whether to compute stress tensor.
        per_atom_energies (bool): Whether to compute per-atom energy decomposition.
        per_atom_stresses (bool): Whether to compute per-atom stress decomposition.
        use_neighbor_list (bool): Whether to use neighbor list optimization.

    Example::

        # Basic usage with default parameters
        lj_model = LennardJonesModel(device=torch.device("cuda"))
        results = lj_model(sim_state)

        # Custom parameterization for Argon
        ar_model = LennardJonesModel(
            sigma=3.405,  # Å
            epsilon=0.0104,  # eV
            cutoff=8.5,  # Å
            compute_stress=True,
        )
    """

    def __init__(
        self,
        sigma: float = 1.0,
        epsilon: float = 1.0,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
        *,  # Force keyword-only arguments
        compute_forces: bool = True,
        compute_stress: bool = False,
        per_atom_energies: bool = False,
        per_atom_stresses: bool = False,
        use_neighbor_list: bool = True,
        cutoff: float | None = None,
    ) -> None:
        """Initialize the Lennard-Jones potential calculator.

        Creates a model with specified interaction parameters and computational
        flags. The model can be configured to compute different properties
        (forces, stresses) and use different optimization strategies.

        Args:
            sigma (float): Length parameter of the Lennard-Jones potential in
                distance units. Controls the size of particles. Defaults to 1.0.
            epsilon (float): Energy parameter of the Lennard-Jones potential in
                energy units. Controls the strength of the interaction.
                Defaults to 1.0.
            device (torch.device | None): Device to run computations on. If None,
                uses CPU. Defaults to None.
            dtype (torch.dtype): Data type for calculations. Defaults to
                torch.float32.
            compute_forces (bool): Whether to compute forces. Defaults to True.
            compute_stress (bool): Whether to compute stress tensor. Defaults to
                False.
            per_atom_energies (bool): Whether to compute per-atom energy
                decomposition. Defaults to False.
            per_atom_stresses (bool): Whether to compute per-atom stress
                decomposition. Defaults to False.
            use_neighbor_list (bool): Whether to use a neighbor list for
                optimization. Significantly faster for large systems.
                Defaults to True.
            cutoff (float | None): Cutoff distance for interactions in distance
                units. If None, uses 2.5*sigma. An explicit value (including 0.0)
                is always honored. Defaults to None.

        Example::

            # Model with custom parameters
            model = LennardJonesModel(
                sigma=3.405,
                epsilon=0.01032,
                device=torch.device("cuda"),
                dtype=torch.float64,
                compute_stress=True,
                per_atom_energies=True,
                cutoff=10.0,
            )
        """
        super().__init__()
        # Underscored backing fields; the code reads them back through
        # self.device / self.dtype / self.compute_forces / self.compute_stress,
        # presumably properties supplied by ModelInterface (not visible here).
        self._device = device or torch.device("cpu")
        self._dtype = dtype
        self._compute_forces = compute_forces
        self._compute_stress = compute_stress
        self.per_atom_energies = per_atom_energies
        self.per_atom_stresses = per_atom_stresses
        self.use_neighbor_list = use_neighbor_list

        # Test for None explicitly: ``cutoff or default`` would silently replace
        # an explicit cutoff of 0.0 with the 2.5*sigma default.
        if cutoff is None:
            cutoff = 2.5 * sigma

        # Convert parameters to tensors
        self.sigma = torch.tensor(sigma, dtype=dtype, device=self.device)
        self.cutoff = torch.tensor(cutoff, dtype=dtype, device=self.device)
        self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self.device)

    def unbatched_forward(
        self,
        state: SimState,
    ) -> dict[str, torch.Tensor]:
        """Compute Lennard-Jones properties for a single unbatched system.

        Internal implementation that processes a single, non-batched simulation
        state. This method handles the core computations of pair interactions,
        neighbor lists, and property calculations.

        Args:
            state (SimState): Single, non-batched simulation state containing
                atomic positions, cell vectors, and other system information.

        Returns:
            dict[str, torch.Tensor]: Dictionary of computed properties:
                - "energy": Total potential energy (scalar)
                - "forces": Atomic forces with shape [n_atoms, 3]
                  (if compute_forces=True)
                - "stress": Stress tensor with shape [3, 3]
                  (if compute_stress=True)
                - "energies": Per-atom energies with shape [n_atoms] that sum to
                  "energy" (if per_atom_energies=True)
                - "stresses": Per-atom stresses with shape [n_atoms, 3, 3] that
                  sum to "stress" (if per_atom_stresses=True)

        Notes:
            This method handles two different approaches:
            1. Neighbor list approach: Efficient for larger systems
            2. Full pairwise calculation: Better for small systems

            The implementation applies cutoff distance to both approaches for
            consistency.
        """
        if not isinstance(state, SimState):
            state = SimState(**state)

        positions = state.positions
        # Presumably (1, 3, 3) -> (3, 3) for a single system; TODO confirm.
        cell = state.row_vector_cell.squeeze()
        pbc = state.pbc

        if self.use_neighbor_list:
            # Get neighbor list using vesin_nl_ts
            mapping, shifts = vesin_nl_ts(
                positions=positions,
                cell=cell,
                pbc=pbc,
                cutoff=self.cutoff,
                sort_id=False,
            )
            # Get displacements using neighbor list
            dr_vec, distances = get_pair_displacements(
                positions=positions,
                cell=cell,
                pbc=pbc,
                pairs=mapping,
                shifts=shifts,
            )
        else:
            # Get all pairwise displacements (dense n_atoms x n_atoms)
            dr_vec, distances = get_pair_displacements(
                positions=positions,
                cell=cell,
                pbc=pbc,
            )
            # Mask out self-interactions
            mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
            distances = distances.masked_fill(mask, float("inf"))
            # Apply cutoff
            mask = distances < self.cutoff
            # Get valid pairs - match neighbor list convention for pair order.
            # Every unordered pair appears here in both orders.
            i, j = torch.where(mask)
            mapping = torch.stack([j, i])
            # Get valid displacements and distances
            dr_vec = dr_vec[mask]
            distances = distances[mask]

        # Calculate pair energies
        pair_energies = lennard_jones_pair(
            distances, sigma=self.sigma, epsilon=self.epsilon
        )
        # Zero out energies beyond cutoff (the neighbor list may include pairs
        # at or just beyond the cutoff)
        mask = distances < self.cutoff
        pair_energies = torch.where(mask, pair_energies, torch.zeros_like(pair_energies))

        # Initialize results with total energy (sum/2 to avoid double counting)
        results = {"energy": 0.5 * pair_energies.sum()}

        if self.per_atom_energies:
            atom_energies = torch.zeros(
                positions.shape[0], dtype=self.dtype, device=self.device
            )
            # The total energy is 0.5 * sum(pair_energies). Split that half
            # equally between the two atoms of each pair (0.25 per endpoint) so
            # that sum(energies) == energy.
            atom_energies.index_add_(0, mapping[0], 0.25 * pair_energies)
            atom_energies.index_add_(0, mapping[1], 0.25 * pair_energies)
            results["energies"] = atom_energies

        if self.compute_forces or self.compute_stress:
            # Calculate forces and apply cutoff
            pair_forces = lennard_jones_pair_force(
                distances, sigma=self.sigma, epsilon=self.epsilon
            )
            pair_forces = torch.where(mask, pair_forces, torch.zeros_like(pair_forces))

            # Project forces along displacement vectors
            force_vectors = (pair_forces / distances)[:, None] * dr_vec

            # NOTE(review): each pair entry below adds to both of its atoms with
            # no 0.5 factor, while the energy above is halved. In the dense branch
            # each unordered pair appears in both orders. Confirm this factor
            # against lennard_jones_pair_force's convention and the neighbor-list
            # type (full vs half).
            if self.compute_forces:
                # Initialize forces tensor
                forces = torch.zeros_like(positions)
                # Add force contributions (f_ij on i, -f_ij on j)
                forces.index_add_(0, mapping[0], -force_vectors)
                forces.index_add_(0, mapping[1], force_vectors)
                results["forces"] = forces

            if self.compute_stress and cell is not None:
                # Compute stress tensor from the pair virial r_ij (x) f_ij
                stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
                volume = torch.abs(torch.linalg.det(cell))

                results["stress"] = -stress_per_pair.sum(dim=0) / volume

                if self.per_atom_stresses:
                    atom_stresses = torch.zeros(
                        (state.positions.shape[0], 3, 3),
                        dtype=self.dtype,
                        device=self.device,
                    )
                    # Half of each pair virial to each endpoint; these sum to
                    # the total stress above.
                    atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair)
                    atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair)
                    results["stresses"] = atom_stresses / volume

        return results

    def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
        """Compute Lennard-Jones energies, forces, and stresses for a system.

        Main entry point for Lennard-Jones calculations. It handles batched states
        by dispatching each batch to the unbatched implementation and combining
        the results.

        Args:
            state (SimState | StateDict): Input state containing atomic positions,
                cell vectors, and other system information. Can be a SimState
                object or a dictionary with the same keys. If the dictionary has
                no "masses" entry, unit masses (one per atom) are supplied.

        Returns:
            dict[str, torch.Tensor]: Dictionary of computed properties:
                - "energy": Potential energy with shape [n_batches]
                - "forces": Atomic forces with shape [n_atoms, 3]
                  (if compute_forces=True)
                - "stress": Stress tensor with shape [n_batches, 3, 3]
                  (if compute_stress=True)
                - "energies": Per-atom energies with shape [n_atoms]
                  (if per_atom_energies=True)
                - "stresses": Per-atom stresses with shape [n_atoms, 3, 3]
                  (if per_atom_stresses=True)

        Raises:
            ValueError: If batch cannot be inferred for multi-cell systems.

        Example::

            # Compute properties for a simulation state
            model = LennardJonesModel(compute_stress=True)
            results = model(sim_state)

            energy = results["energy"]  # Shape: [n_batches]
            forces = results["forces"]  # Shape: [n_atoms, 3]
            stress = results["stress"]  # Shape: [n_batches, 3, 3]
            energies = results["energies"]  # Shape: [n_atoms]
            stresses = results["stresses"]  # Shape: [n_atoms, 3, 3]
        """
        if isinstance(state, dict):
            state_kwargs = dict(state)
            # Only fill in masses when the caller did not provide them. Passing
            # masses=... alongside **state would raise TypeError on a duplicate
            # keyword. The LJ computation itself never reads masses.
            if "masses" not in state_kwargs:
                positions = state_kwargs["positions"]
                state_kwargs["masses"] = torch.ones(
                    positions.shape[0], dtype=positions.dtype, device=positions.device
                )
            state = SimState(**state_kwargs)

        if state.batch is None and state.cell.shape[0] > 1:
            raise ValueError("Batch can only be inferred for batch size 1.")

        outputs = [self.unbatched_forward(state[i]) for i in range(state.n_batches)]
        properties = outputs[0]

        # we always return tensors
        # per atom properties are returned as (atoms, ...) tensors
        # global properties are returned as shape (..., n) tensors
        results = {}
        for key in ("stress", "energy"):
            if key in properties:
                results[key] = torch.stack([out[key] for out in outputs])
        for key in ("forces", "energies", "stresses"):
            if key in properties:
                results[key] = torch.cat([out[key] for out in outputs], dim=0)

        return results