Source code for torch_sim.models.metatomic

"""Wrapper for metatomic-based models in TorchSim.

This module provides a TorchSim wrapper of metatomic models for computing
energies, forces, and stresses for atomistic systems, including batched computations
for multiple systems simultaneously.

The MetatomicModel class adapts metatomic models to the ModelInterface protocol,
allowing them to be used within the broader torch_sim simulation framework.

Notes:
    This module depends on the metatomic-torch package.
"""

import traceback
import warnings
from pathlib import Path
from typing import Any

import torch
import vesin.metatomic

import torch_sim as ts
from torch_sim.models.interface import ModelInterface
from torch_sim.typing import StateDict


try:
    from metatomic.torch import (
        ModelEvaluationOptions,
        ModelOutput,
        System,
        load_atomistic_model,
    )
    from metatrain.utils.io import load_model

except ImportError as exc:
    warnings.warn(f"Metatomic import failed: {traceback.format_exc()}", stacklevel=2)

    class MetatomicModel(torch.nn.Module, ModelInterface):
        """Metatomic model wrapper for torch_sim.

        This class is a placeholder for the MetatomicModel class.
        It raises an ImportError if metatomic is not installed.
        """

        def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None:
            """Dummy init for type checking."""
            raise err


[docs] class MetatomicModel(torch.nn.Module, ModelInterface): """Computes energies for a list of systems using a metatomic model. This class wraps a metatomic model to compute energies, forces, and stresses for atomic systems within the TorchSim framework. It supports batched calculations for multiple systems and handles the necessary transformations between TorchSim's data structures and metatomic's expected inputs. Attributes: ... """ def __init__( self, model: str | Path | None = None, extensions_path: str | Path | None = None, device: torch.device | str | None = None, *, check_consistency: bool = False, compute_forces: bool = True, compute_stress: bool = True, ) -> None: """Initialize the metatomic model for energy, force and stress calculations. Sets up a metatomic model for energy, force, and stress calculations within the TorchSim framework. The model can be initialized with atomic numbers and batch indices, or these can be provided during the forward pass. Args: model (str | Path | None): Path to the metatomic model file or a pre-defined model name. Currently only "pet-mad" (https://arxiv.org/abs/2503.14118) is supported as a pre-defined model. If None, defaults to "pet-mad". extensions_path (str | Path | None): Optional, path to the folder containing compiled extensions for the model. device (torch.device | None): Device on which to run the model. If None, defaults to "cuda" if available, otherwise "cpu". check_consistency (bool): Whether to perform various consistency checks during model evaluation. This should only be used in case of anomalous behavior, as it can hurt performance significantly. compute_forces (bool): Whether to compute forces. compute_stress (bool): Whether to compute stresses. Raises: TypeError: If model is neither a path nor "pet-mad". """ super().__init__() if model is None: raise ValueError( "A model path, or the name of a pre-defined model, must be provided. " 'Currently only "pet-mad" is available as a pre-defined model.' ) if model == "pet-mad": path = "https://huggingface.co/lab-cosmo/pet-mad/resolve/main/models/pet-mad-latest.ckpt" self._model = load_model(path).export() elif model.endswith(".ckpt"): path = model self._model = load_model(path).export() elif model.endswith(".pt"): path = model self._model = load_atomistic_model(path, extensions_path) else: raise ValueError('Model must be a path to a .ckpt/.pt file, or "pet-mad".') if "energy" not in self._model.capabilities().outputs: raise ValueError( "This model does not support energy predictions. " "The model must have an `energy` output to be used in torch-sim." ) self._device = device or ("cuda" if torch.cuda.is_available() else "cpu") if isinstance(self._device, str): self._device = torch.device(self._device) if self._device.type not in self._model.capabilities().supported_devices: raise ValueError( f"Model does not support device {self._device}. Supported devices: " f"{self._model.capabilities().supported_devices}. You might want to " f"set the `device` argument to a supported device." ) self._dtype = getattr(torch, self._model.capabilities().dtype) self._model.to(self._device) self._compute_forces = compute_forces self._compute_stress = compute_stress self._memory_scales_with = "n_atoms_x_density" # for the majority of models self._check_consistency = check_consistency self._requested_neighbor_lists = self._model.requested_neighbor_lists() self._evaluation_options = ModelEvaluationOptions( length_unit="angstrom", outputs={ "energy": ModelOutput( quantity="energy", unit="eV", per_atom=False, ) }, )
[docs] def forward( # noqa: C901, PLR0915 self, state: ts.SimState | StateDict, ) -> dict[str, torch.Tensor]: """Compute energies, forces, and stresses for the given atomic systems. Processes the provided state information and computes energies, forces, and stresses using the underlying metatomic model. Handles batched calculations for multiple systems as well as constructing the necessary neighbor lists. Args: state (SimState | StateDict): State object containing positions, cell, and other system information. Can be either a SimState object or a dictionary with the relevant fields. Returns: dict[str, torch.Tensor]: Computed properties: - 'energy': System energies with shape [n_systems] - 'forces': Atomic forces with shape [n_atoms, 3] if compute_forces=True - 'stress': System stresses with shape [n_systems, 3, 3] if compute_stress=True """ # Extract required data from input if isinstance(state, dict): state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) # Input validation is already done inside the forward method of the # AtomisticModel class, so we don't need to do it again here. atomic_numbers = state.atomic_numbers cell = state.row_vector_cell positions = state.positions pbc = state.pbc # Check dtype (metatomic models require a specific input dtype) if positions.dtype != self._dtype: raise TypeError( f"Positions dtype {positions.dtype} does not match model dtype " f"{self._dtype}" ) # Compared to other models, metatomic models have two peculiarities: # - different structures are fed to the models separately as a list of System # objects, and not as a single graph-like batch # - the model does not compute forces and stresses itself, but rather the # caller code needs to call torch.autograd.grad or similar to compute them # from the energy output # Process each system separately systems: list[System] = [] strains = [] for b in range(len(cell)): system_mask = state.batch == b system_positions = positions[system_mask] system_cell = cell[b] system_pbc = torch.tensor( [pbc, pbc, pbc], device=self._device, dtype=torch.bool ) system_atomic_numbers = atomic_numbers[system_mask] # Create a System object for this system if self._compute_forces: system_positions.requires_grad_() if self._compute_stress: strain = torch.eye( 3, device=self._device, dtype=self._dtype, requires_grad=True ) system_positions = system_positions @ strain system_cell = system_cell @ strain strains.append(strain) systems.append( System( positions=system_positions, types=system_atomic_numbers, cell=system_cell, pbc=system_pbc, ) ) # Calculate the required neighbor list(s) for all the systems # move data to CPU because vesin only supports CPU for now systems = [system.to(device="cpu") for system in systems] vesin.metatomic.compute_requested_neighbors( systems, system_length_unit="Angstrom", model=self._model ) # move back to the proper device systems = [system.to(device=self.device) for system in systems] # Get model output model_outputs = self._model( systems=systems, options=self._evaluation_options, check_consistency=self._check_consistency, ) results = {} results["energy"] = model_outputs["energy"].block().values.detach().squeeze(-1) # Compute forces and/or stresses if requested tensors_for_autograd = [] if self._compute_forces: for system in systems: tensors_for_autograd.append(system.positions) # noqa: PERF401 if self._compute_stress: for strain in strains: tensors_for_autograd.append(strain) # noqa: PERF402 if self._compute_forces or self._compute_stress: derivatives = torch.autograd.grad( outputs=model_outputs["energy"].block().values, inputs=tensors_for_autograd, grad_outputs=torch.ones_like(model_outputs["energy"].block().values), ) else: derivatives = [] results_by_system: dict[str, list[torch.Tensor]] = {} if self._compute_forces and self._compute_stress: results_by_system["forces"] = [-d for d in derivatives[: len(systems)]] results_by_system["stress"] = [ d / torch.abs(torch.det(system.cell.detach())) for d, system in zip(derivatives[len(systems) :], systems, strict=False) ] elif self._compute_forces: results_by_system["forces"] = [-d for d in derivatives] elif self._compute_stress: results_by_system["stress"] = [ d / torch.abs(torch.det(system.cell.detach())) for d, system in zip(derivatives, systems, strict=False) ] else: pass # Concatenate/stack forces and stresses if self._compute_forces: if len(results_by_system["forces"]) > 0: results["forces"] = torch.cat(results_by_system["forces"]) else: results["forces"] = torch.empty_like(positions) if self._compute_stress: if len(results_by_system["stress"]) > 0: results["stress"] = torch.stack(results_by_system["stress"]) else: results["stress"] = torch.empty_like(cell) return results