Source code for torch_sim.models.metatensor

"""Wrapper for metatensor-based models in TorchSim.

This module provides a TorchSim wrapper of metatensor models for computing
energies, forces, and stresses for atomistic systems, including batched computations
for multiple systems simultaneously.

The MetatensorModel class adapts metatensor models to the ModelInterface protocol,
allowing them to be used within the broader torch_sim simulation framework.

Notes:
    This module depends on the metatensor-torch, metatrain, and vesin packages.
"""

from pathlib import Path

import torch
import vesin.torch.metatensor

from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState
from torch_sim.typing import StateDict


def _missing_dependency(message: str, error: ImportError):
    """Return a stand-in for an optional import that failed.

    The stand-in accepts any arguments and, when called, raises
    ``ImportError(message)`` chained to the original ``error``. Users then see
    which package is missing instead of a ``NameError`` raised from inside
    ``MetatensorModel``.
    """

    def _raise(*args, **kwargs):  # noqa: ARG001
        raise ImportError(message) from error

    return _raise


# ``MetatensorModel`` is defined unconditionally further down in this module, so
# a placeholder class defined here would always be overwritten. Instead, every
# optional name that failed to import is bound to a stand-in that raises an
# informative ImportError when it is first used. Each package gets its own
# ``try`` so that a missing ``metatrain`` does not disable ``.pt`` models, which
# only need metatensor.
try:
    from metatensor.torch.atomistic import (
        ModelEvaluationOptions,
        ModelOutput,
        System,
        load_atomistic_model,
    )
except ImportError as _metatensor_error:
    ModelEvaluationOptions = ModelOutput = System = load_atomistic_model = (
        _missing_dependency(
            "metatensor must be installed to use MetatensorModel.",
            _metatensor_error,
        )
    )

try:
    from metatrain.utils.io import load_model
except ImportError as _metatrain_error:
    load_model = _missing_dependency(
        "metatrain must be installed to load .ckpt or pet-mad models "
        "with MetatensorModel.",
        _metatrain_error,
    )


class MetatensorModel(torch.nn.Module, ModelInterface):
    """Computes energies for a list of systems using a metatensor model.

    This class wraps a metatensor model to compute energies, forces, and stresses
    for atomic systems within the TorchSim framework. It supports batched
    calculations for multiple systems and handles the necessary transformations
    between TorchSim's data structures and metatensor's expected inputs.

    Attributes:
        _model: The loaded metatensor atomistic model.
        _device (torch.device): Device the model runs on.
        _dtype (torch.dtype): Floating-point dtype the model requires for inputs.
        _compute_forces (bool): Whether ``forward`` returns forces.
        _compute_stress (bool): Whether ``forward`` returns stresses.
        _memory_scales_with (str): Memory-scaling hint, "n_atoms_x_density".
        _check_consistency (bool): Whether model consistency checks are enabled.
        _requested_neighbor_lists: Neighbor-list options requested by the model.
        _evaluation_options (ModelEvaluationOptions): Requests a per-system
            ``energy`` output in eV, with lengths in angstrom.
    """

    def __init__(
        self,
        model: str | Path | None = None,
        extensions_path: str | Path | None = None,
        device: torch.device | str | None = None,
        *,
        check_consistency: bool = False,
        compute_forces: bool = True,
        compute_stress: bool = True,
    ) -> None:
        """Initialize the metatensor model for energy, force and stress calculations.

        Args:
            model (str | Path | None): Path to a metatensor model file (``.ckpt`` or
                ``.pt``) or the name of a pre-defined model. Currently only
                "pet-mad" (https://arxiv.org/abs/2503.14118) is supported as a
                pre-defined model. Although the default is None, a value must be
                provided.
            extensions_path (str | Path | None): Optional, path to the folder
                containing compiled extensions for the model. Only used for
                ``.pt`` models.
            device (torch.device | str | None): Device on which to run the model.
                If None, defaults to "cuda" if available, otherwise "cpu".
            check_consistency (bool): Whether to perform various consistency checks
                during model evaluation. This should only be used in case of
                anomalous behavior, as it can hurt performance significantly.
            compute_forces (bool): Whether to compute forces.
            compute_stress (bool): Whether to compute stresses.

        Raises:
            ValueError: If ``model`` is None, is neither a ``.ckpt``/``.pt`` path
                nor "pet-mad", if the model has no ``energy`` output, or if the
                model does not support the requested device.
        """
        super().__init__()

        if model is None:
            raise ValueError(
                "A model path, or the name of a pre-defined model, must be provided. "
                'Currently only "pet-mad" is available as a pre-defined model.'
            )

        # Normalize to str: ``pathlib.Path`` has no ``endswith``.
        model_name = str(model)
        if model_name == "pet-mad":
            path = "https://huggingface.co/lab-cosmo/pet-mad/resolve/main/models/pet-mad-latest.ckpt"
            self._model = load_model(path).export()
        elif model_name.endswith(".ckpt"):
            self._model = load_model(model_name).export()
        elif model_name.endswith(".pt"):
            self._model = load_atomistic_model(model_name, extensions_path)
        else:
            raise ValueError('Model must be a path to a .ckpt/.pt file, or "pet-mad".')

        capabilities = self._model.capabilities()
        if "energy" not in capabilities.outputs:
            raise ValueError(
                "This model does not support energy predictions. "
                "The model must have an `energy` output to be used in torch-sim."
            )

        # torch.device accepts both strings and existing torch.device objects.
        self._device = torch.device(
            device or ("cuda" if torch.cuda.is_available() else "cpu")
        )
        if self._device.type not in capabilities.supported_devices:
            raise ValueError(
                f"Model does not support device {self._device}. Supported devices: "
                f"{capabilities.supported_devices}. You might want to "
                f"set the `device` argument to a supported device."
            )

        # capabilities.dtype is a string such as "float32"; map it to torch.dtype.
        self._dtype = getattr(torch, capabilities.dtype)
        self._model.to(self._device)
        self._compute_forces = compute_forces
        self._compute_stress = compute_stress
        self._memory_scales_with = "n_atoms_x_density"  # for the majority of models
        self._check_consistency = check_consistency

        self._requested_neighbor_lists = self._model.requested_neighbor_lists()
        self._evaluation_options = ModelEvaluationOptions(
            length_unit="angstrom",
            outputs={
                "energy": ModelOutput(
                    quantity="energy",
                    unit="eV",
                    per_atom=False,
                )
            },
        )

    def forward(
        self,
        state: SimState | StateDict,
    ) -> dict[str, torch.Tensor]:
        """Compute energies, forces, and stresses for the given atomic systems.

        Processes the provided state information and computes energies, forces, and
        stresses using the underlying metatensor model. Handles batched calculations
        for multiple systems as well as constructing the necessary neighbor lists.

        Args:
            state (SimState | StateDict): State object containing positions, cell,
                and other system information. Can be either a SimState object or a
                dictionary with the relevant fields.

        Returns:
            dict[str, torch.Tensor]: Dictionary containing:
                - 'energy': System energies with shape [n_systems]
                - 'forces': Atomic forces with shape [n_atoms, 3] if
                    compute_forces=True
                - 'stress': System stresses with shape [n_systems, 3, 3] if
                    compute_stress=True

        Raises:
            TypeError: If the positions dtype does not match the model dtype.
        """
        if isinstance(state, dict):
            # SimState requires masses, which the energy model does not use. Default
            # to one mass per atom unless the caller supplied them.
            state = SimState(
                **{"masses": torch.ones_like(state["positions"][:, 0]), **state}
            )

        # Input validation is already done inside the forward method of the
        # MetatensorAtomisticModel class, so we don't need to do it again here.
        positions = state.positions
        cell = state.row_vector_cell

        # Metatensor models require a specific input dtype.
        if positions.dtype != self._dtype:
            raise TypeError(
                f"Positions dtype {positions.dtype} does not match model dtype "
                f"{self._dtype}"
            )

        # Compared to other models, metatensor models have two peculiarities:
        # - different structures are fed to the models separately as a list of System
        #   objects, and not as a single graph-like batch
        # - the model does not compute forces and stresses itself, but rather the
        #   caller code needs to call torch.autograd.grad or similar to compute them
        #   from the energy output
        systems, position_leaves, strains = self._build_systems(state, positions, cell)
        systems = self._attach_neighbor_lists(systems)

        model_outputs = self._model(
            systems=systems,
            options=self._evaluation_options,
            check_consistency=self._check_consistency,
        )
        energies = model_outputs["energy"].block().values

        results: dict[str, torch.Tensor] = {"energy": energies.detach().squeeze(-1)}

        # Differentiate w.r.t. the leaf tensors created in _build_systems. They sit
        # upstream of the strain product, the device round-trip and the neighbor
        # list, so every path through which the energy depends on them is included.
        derivatives = self._energy_derivatives(energies, [*position_leaves, *strains])
        n_position_grads = len(position_leaves)

        if self._compute_forces:
            forces = [-grad for grad in derivatives[:n_position_grads]]
            results["forces"] = torch.cat(forces) if forces else torch.empty_like(positions)

        if self._compute_stress:
            # stress = (1 / V) dE/d(strain), evaluated at zero strain.
            stresses = [
                grad / torch.abs(torch.det(system.cell.detach()))
                for grad, system in zip(
                    derivatives[n_position_grads:], systems, strict=True
                )
            ]
            results["stress"] = torch.stack(stresses) if stresses else torch.empty_like(cell)

        return results

    def _build_systems(
        self,
        state: SimState,
        positions: torch.Tensor,
        cell: torch.Tensor,
    ) -> "tuple[list[System], list[torch.Tensor], list[torch.Tensor]]":
        """Split the batch into one metatensor System per structure.

        Returns the systems, the per-system position tensors marked for gradients
        (empty unless forces are requested), and the per-system identity strain
        tensors (empty unless stresses are requested).
        """
        pbc = state.pbc
        # Read-only, and identical for every system in the batch.
        system_pbc = torch.tensor([pbc, pbc, pbc], device=self._device, dtype=torch.bool)

        systems: list[System] = []
        position_leaves: list[torch.Tensor] = []
        strains: list[torch.Tensor] = []
        for b in range(len(cell)):
            system_mask = state.batch == b
            system_positions = positions[system_mask]
            system_cell = cell[b]

            if self._compute_forces:
                system_positions.requires_grad_()
                position_leaves.append(system_positions)
            if self._compute_stress:
                # Identity strain: the values are unchanged, but the energy becomes
                # differentiable w.r.t. a homogeneous deformation of the system.
                strain = torch.eye(
                    3, device=self._device, dtype=self._dtype, requires_grad=True
                )
                strains.append(strain)
                system_positions = system_positions @ strain
                system_cell = system_cell @ strain

            systems.append(
                System(
                    positions=system_positions,
                    types=state.atomic_numbers[system_mask],
                    cell=system_cell,
                    pbc=system_pbc,
                )
            )
        return systems, position_leaves, strains

    def _attach_neighbor_lists(self, systems: "list[System]") -> "list[System]":
        """Compute the model's requested neighbor lists and return systems on device."""
        # Move data to CPU because vesin only supports CPU for now.
        cpu_systems = [system.to(device="cpu") for system in systems]
        vesin.torch.metatensor.compute_requested_neighbors(
            cpu_systems, system_length_unit="Angstrom", model=self._model
        )
        # Move back to the model device.
        return [system.to(device=self._device) for system in cpu_systems]

    @staticmethod
    def _energy_derivatives(
        energies: torch.Tensor, inputs: list[torch.Tensor]
    ) -> list[torch.Tensor]:
        """Return d(sum of energies)/d(x) for each x in ``inputs``.

        Inputs the energy does not depend on get a zero gradient. An empty
        ``inputs`` list (nothing requested, or an empty batch) returns ``[]``.
        """
        if not inputs:
            return []
        grads = torch.autograd.grad(
            outputs=energies,
            inputs=inputs,
            grad_outputs=torch.ones_like(energies),
            allow_unused=True,
        )
        return [
            torch.zeros_like(tensor) if grad is None else grad
            for grad, tensor in zip(grads, inputs, strict=True)
        ]