# Source code for torch_sim.models.morse

"""Anharmonic interatomic potential for molecular dynamics.

This module implements the Morse potential for molecular dynamics simulations.
The Morse potential provides a more realistic description of anharmonic bond
behavior than simple harmonic potentials, capturing bond breaking and formation.
It computes energies and, optionally, forces, stresses, and per-atom
decompositions, with optional neighbor-list acceleration.

Example::

    # Create a Morse model with default parameters
    model = MorseModel(device=torch.device("cuda"))

    # Calculate properties for a simulation state
    output = model(sim_state)
    energy = output["energy"]
    forces = output["forces"]

Notes:
    The Morse potential follows the form:
    V(r) = D_e * (1 - exp(-a(r-r_e)))^2

    Where:
    - D_e (epsilon) is the well depth (dissociation energy)
    - r_e (sigma) is the equilibrium bond distance
    - a (alpha) controls the width of the potential well
"""

import torch

from torch_sim.models.interface import ModelInterface
from torch_sim.neighbors import vesin_nl_ts
from torch_sim.state import SimState
from torch_sim.transforms import get_pair_displacements
from torch_sim.typing import StateDict
from torch_sim.unbatched.models.morse import morse_pair, morse_pair_force


class MorseModel(torch.nn.Module, ModelInterface):
    """Morse potential energy and force calculator.

    Implements the Morse potential for molecular dynamics simulations. This model
    is particularly useful for modeling covalent bonds as it can accurately
    describe bond stretching, breaking, and anharmonic behavior. Unlike the
    Lennard-Jones potential, Morse is often better for cases where accurate
    dissociation energy and bond dynamics are important.

    Attributes:
        sigma (torch.Tensor): Equilibrium bond length (r_e) in distance units.
        epsilon (torch.Tensor): Dissociation energy (D_e) in energy units.
        alpha (torch.Tensor): Parameter controlling the width/steepness of the
            potential.
        cutoff (torch.Tensor): Distance cutoff for truncating potential calculation.
        device (torch.device): Device where calculations are performed.
        dtype (torch.dtype): Data type used for calculations.
        compute_forces (bool): Whether to compute atomic forces.
        compute_stress (bool): Whether to compute stress tensor.
        per_atom_energies (bool): Whether to compute per-atom energy decomposition.
        per_atom_stresses (bool): Whether to compute per-atom stress decomposition.
        use_neighbor_list (bool): Whether to use neighbor list optimization.

    Examples:
        ```py
        # Basic usage with default parameters
        morse_model = MorseModel(device=torch.device("cuda"))
        results = morse_model(sim_state)

        # Model parameterized for O-H bonds in water (distances in Å, energies in eV)
        oh_model = MorseModel(
            sigma=0.96,
            epsilon=4.52,
            alpha=2.0,
            compute_forces=True,
            compute_stress=True,
        )
        ```
    """

    def __init__(
        self,
        sigma: float = 1.0,
        epsilon: float = 5.0,
        alpha: float = 5.0,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
        *,  # Force keyword-only arguments
        compute_forces: bool = False,
        compute_stress: bool = False,
        per_atom_energies: bool = False,
        per_atom_stresses: bool = False,
        use_neighbor_list: bool = True,
        cutoff: float | None = None,
    ) -> None:
        """Initialize the Morse potential calculator.

        Creates a model with specified interaction parameters and computational
        flags. The Morse potential is defined by three key parameters: sigma
        (equilibrium distance), epsilon (dissociation energy), and alpha (width
        control).

        Args:
            sigma (float): Equilibrium bond distance (r_e) in distance units.
                Defaults to 1.0.
            epsilon (float): Dissociation energy (D_e) in energy units.
                Defaults to 5.0.
            alpha (float): Controls the width/steepness of the potential well.
                Larger values create a narrower well. Defaults to 5.0.
            device (torch.device | None): Device to run computations on. If None,
                uses CPU. Defaults to None.
            dtype (torch.dtype): Data type for calculations. Defaults to
                torch.float32.
            compute_forces (bool): Whether to compute forces. Defaults to False.
            compute_stress (bool): Whether to compute stress tensor.
                Defaults to False.
            per_atom_energies (bool): Whether to compute per-atom energy
                decomposition. Defaults to False.
            per_atom_stresses (bool): Whether to compute per-atom stress
                decomposition. Defaults to False.
            use_neighbor_list (bool): Whether to use a neighbor list for
                optimization. Significantly faster for large systems.
                Defaults to True.
            cutoff (float | None): Cutoff distance for interactions in distance
                units. If None, uses 2.5*sigma. Defaults to None.

        Examples:
            ```py
            # Basic model with default parameters
            model = MorseModel()

            # Model for diatomic hydrogen
            model = MorseModel(
                sigma=0.74,  # Å
                epsilon=4.75,  # eV
                alpha=1.94,  # Steepness parameter
                compute_forces=True,
            )
            ```

        Notes:
            The alpha parameter can be related to the harmonic force constant k
            and dissociation energy D_e by: alpha = sqrt(k/(2*D_e))
        """
        super().__init__()
        self._device = device or torch.device("cpu")
        self._dtype = dtype
        self._compute_forces = compute_forces
        self._compute_stress = compute_stress
        self._per_atom_energies = per_atom_energies
        self._per_atom_stresses = per_atom_stresses
        self.use_neighbor_list = use_neighbor_list

        # Explicit None check: ``cutoff or ...`` would silently replace a
        # user-supplied cutoff of 0.0 with the default.
        if cutoff is None:
            cutoff = 2.5 * sigma

        # Convert parameters to tensors
        self.sigma = torch.tensor(sigma, dtype=self.dtype, device=self.device)
        self.cutoff = torch.tensor(cutoff, dtype=self.dtype, device=self.device)
        self.epsilon = torch.tensor(epsilon, dtype=self.dtype, device=self.device)
        self.alpha = torch.tensor(alpha, dtype=self.dtype, device=self.device)

    @staticmethod
    def _as_sim_state(state: SimState | StateDict) -> SimState:
        """Return ``state`` as a SimState, building one from a dict if needed.

        Nothing in this model reads masses, so when a dict without ``"masses"``
        is given, unit masses (one per atom) are supplied only to satisfy the
        ``SimState`` constructor. Masses already present in the dict are kept.
        """
        if not isinstance(state, dict):
            return state
        positions = state["positions"]
        masses = state.get("masses")
        if masses is None:
            masses = torch.ones(
                positions.shape[0], dtype=positions.dtype, device=positions.device
            )
        return SimState(**{**state, "masses": masses})

    def unbatched_forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
        """Compute Morse potential properties for a single unbatched system.

        Internal implementation that processes a single, non-batched simulation
        state. It builds the list of interacting pairs (from the neighbor list or
        from all pairwise displacements), then evaluates the Morse energy and,
        as configured, forces, stress and per-atom decompositions.

        Args:
            state (SimState | StateDict): Single, non-batched simulation state or
                equivalent dictionary containing atomic positions, cell vectors,
                and other system information.

        Returns:
            dict[str, torch.Tensor]: Dictionary of computed properties:
                - "energy": Total potential energy (scalar)
                - "forces": Atomic forces with shape [n_atoms, 3]
                  (if compute_forces=True)
                - "stress": Stress tensor with shape [3, 3]
                  (if compute_stress=True)
                - "energies": Per-atom energies with shape [n_atoms]
                  (if per_atom_energies=True); they sum to "energy"
                - "stresses": Per-atom stresses with shape [n_atoms, 3, 3]
                  (if per_atom_stresses=True and compute_stress=True); they sum
                  to "stress"

        Notes:
            The dense (no neighbor list) path keeps every ordered pair (i, j),
            i != j, inside the cutoff, i.e. a *full* pair list in which each
            interacting pair appears twice. Every per-pair contribution (energy,
            force, virial) is therefore weighted by 0.5 so that energy, forces
            and stress are mutually consistent.
            NOTE(review): the neighbor-list path is assumed to return a full list
            too (both (i, j) and (j, i)), matching the dense path; confirm
            against ``vesin_nl_ts``.
        """
        state = self._as_sim_state(state)

        positions = state.positions
        cell = state.row_vector_cell.squeeze()
        pbc = state.pbc

        if self.use_neighbor_list:
            mapping, shifts = vesin_nl_ts(
                positions=positions,
                cell=cell,
                pbc=pbc,
                cutoff=self.cutoff,
                sort_id=False,
            )
            dr_vec, distances = get_pair_displacements(
                positions=positions,
                cell=cell,
                pbc=pbc,
                pairs=mapping,
                shifts=shifts,
            )
        else:
            dr_vec, distances = get_pair_displacements(
                positions=positions,
                cell=cell,
                pbc=pbc,
            )
            # Exclude self-interactions by pushing the diagonal past the cutoff.
            self_mask = torch.eye(
                positions.shape[0], dtype=torch.bool, device=self.device
            )
            distances = distances.masked_fill(self_mask, float("inf"))
            pair_mask = distances < self.cutoff
            i, j = torch.where(pair_mask)
            # Match the neighbor list convention for pair order.
            mapping = torch.stack([j, i])
            dr_vec = dr_vec[pair_mask]
            distances = distances[pair_mask]

        # Each interacting pair appears twice in the full pair list, so every
        # directed-pair contribution carries half of the pair interaction.
        pair_weight = 0.5

        # Calculate pair energies and apply cutoff
        in_cutoff = distances < self.cutoff
        pair_energies = morse_pair(
            distances, sigma=self.sigma, epsilon=self.epsilon, alpha=self.alpha
        )
        pair_energies = torch.where(
            in_cutoff, pair_energies, torch.zeros_like(pair_energies)
        )

        results = {"energy": pair_weight * pair_energies.sum()}

        if self._per_atom_energies:
            atom_energies = torch.zeros(
                positions.shape[0], dtype=self.dtype, device=self.device
            )
            # Split each directed pair's weighted energy evenly between its two
            # atoms so the per-atom energies sum exactly to results["energy"].
            atom_share = 0.5 * pair_weight * pair_energies
            atom_energies.index_add_(0, mapping[0], atom_share)
            atom_energies.index_add_(0, mapping[1], atom_share)
            results["energies"] = atom_energies

        if self.compute_forces or self.compute_stress:
            pair_forces = morse_pair_force(
                distances, sigma=self.sigma, epsilon=self.epsilon, alpha=self.alpha
            )
            pair_forces = torch.where(
                in_cutoff, pair_forces, torch.zeros_like(pair_forces)
            )
            # Project the (half-weighted) pair force onto the displacement
            # direction; the weight carries through to forces and the virial.
            force_vectors = (pair_weight * pair_forces / distances)[:, None] * dr_vec

            if self.compute_forces:
                forces = torch.zeros_like(state.positions)
                # Equal and opposite contributions on the two atoms of each pair.
                forces.index_add_(0, mapping[0], -force_vectors)
                forces.index_add_(0, mapping[1], force_vectors)
                results["forces"] = forces

            if self.compute_stress and state.cell is not None:
                stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
                volume = torch.abs(torch.linalg.det(state.cell))

                results["stress"] = -stress_per_pair.sum(dim=0) / volume

                if self._per_atom_stresses:
                    atom_stresses = torch.zeros(
                        (state.positions.shape[0], 3, 3),
                        dtype=self.dtype,
                        device=self.device,
                    )
                    # Half of each pair virial to each atom; sums to "stress".
                    atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair)
                    atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair)
                    results["stresses"] = atom_stresses / volume

        return results

    def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
        """Compute Morse potential energies, forces, and stresses for a system.

        Main entry point for Morse potential calculations that handles batched
        states by dispatching each batch to the unbatched implementation and
        combining results.

        Args:
            state (SimState | StateDict): Input state containing atomic positions,
                cell vectors, and other system information. Can be a SimState
                object or a dictionary with the same keys.

        Returns:
            dict[str, torch.Tensor]: Dictionary of computed properties:
                - "energy": Potential energy with shape [n_batches]
                - "forces": Atomic forces with shape [n_atoms, 3]
                  (if compute_forces=True)
                - "stress": Stress tensor with shape [n_batches, 3, 3]
                  (if compute_stress=True)
                - "energies": Per-atom energies with shape [n_atoms]
                  (if per_atom_energies=True)
                - "stresses": Per-atom stresses with shape [n_atoms, 3, 3]
                  (if per_atom_stresses=True and compute_stress=True)

        Raises:
            ValueError: If batch cannot be inferred for multi-cell systems.

        Examples:
            ```py
            # Compute properties for a simulation state
            model = MorseModel(compute_forces=True)
            results = model(sim_state)

            energy = results["energy"]  # Shape: [n_batches]
            forces = results["forces"]  # Shape: [n_atoms, 3]
            ```
        """
        state = self._as_sim_state(state)

        if state.batch is None and state.cell.shape[0] > 1:
            raise ValueError("Batch can only be inferred for batch size 1.")

        outputs = [self.unbatched_forward(state[i]) for i in range(state.n_batches)]
        available = outputs[0]

        # we always return tensors
        # per-system properties are stacked into (n_batches, ...) tensors
        # per-atom properties are concatenated into (n_atoms, ...) tensors
        results = {}
        for key in ("stress", "energy"):
            if key in available:
                results[key] = torch.stack([out[key] for out in outputs])
        for key in ("forces", "energies", "stresses"):
            if key in available:
                results[key] = torch.cat([out[key] for out in outputs], dim=0)

        return results