Source code for torch_sim.models.sevennet

"""TorchSim wrapper for SevenNet models."""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import torch

from torch_sim.elastic import voigt_6_to_full_3x3_stress
from torch_sim.models.interface import ModelInterface
from torch_sim.neighbors import vesin_nl_ts
from torch_sim.state import SimState


if TYPE_CHECKING:
    from collections.abc import Callable

    from sevenn.nn.sequential import AtomGraphSequential

    from torch_sim.typing import StateDict


try:
    import sevenn._keys as key
    import torch
    from sevenn.atom_graph_data import AtomGraphData
    from sevenn.calculator import torch_script_type
    from torch_geometric.loader.dataloader import Collater

except ImportError:

    class SevenNetModel(torch.nn.Module, ModelInterface):
        """SevenNet model wrapper for torch_sim.

        This class is a placeholder for the SevenNetModel class.
        It raises an ImportError if sevenn is not installed.
        """

        def __init__(self, *args, **kwargs) -> None:  # noqa: ARG002
            """Dummy constructor to raise ImportError."""
            raise ImportError("sevenn must be installed to use this model.")


[docs]
class SevenNetModel(torch.nn.Module, ModelInterface):
    """Computes atomistic energies, forces, and stresses using a SevenNet model.

    This class wraps a SevenNet model to compute energies, forces, and stresses for
    atomistic systems. It handles model initialization and configuration, and provides
    a forward pass that accepts a SimState object and returns model predictions.

    Examples:
        >>> model = SevenNetModel(model=loaded_sevenn_model)
        >>> results = model(state)
    """

    def __init__(
        self,
        model: AtomGraphSequential,
        *,  # force remaining arguments to be keyword-only
        modal: str | None = None,
        neighbor_list_fn: Callable = vesin_nl_ts,
        device: torch.device | str | None = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Initialize the SevenNetModel with the specified configuration.

        Wraps an already-loaded SevenNet model and sets up the model parameters for
        subsequent use in energy and force calculations.

        Args:
            model (AtomGraphSequential): The SevenNet model to wrap.
            modal (str | None): Modal (fidelity) if the given model is a multi-modal
                model. For 7net-mf-ompa, it should be one of 'mpa' (MPtrj + sAlex)
                or 'omat24' (OMat24).
            neighbor_list_fn (Callable): Neighbor list function to use.
                Default is vesin_nl_ts.
            device (torch.device | str | None): Device to run the model on.
            dtype (torch.dtype): Data type for computation. Defaults to torch.float32.

        Raises:
            ValueError: If the model's cutoff is not initialized.
            ValueError: If the model has a modal_map but no modal is given.
            ValueError: If the given modal is not in the modal_map.
        """
        super().__init__()
        self._device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        if isinstance(self._device, str):
            self._device = torch.device(self._device)

        if dtype is not torch.float32:
            warnings.warn(
                "SevenNetModel currently only supports float32, "
                "but received a different dtype",
                UserWarning,
                stacklevel=2,
            )
        self._dtype = dtype
        self._memory_scales_with = "n_atoms_x_density"
        self._compute_stress = True
        self._compute_forces = True

        if model.cutoff == 0.0:
            raise ValueError("Model cutoff seems not initialized")
        # TODO: from sevenn, not sure if needed
        model.eval_type_map = torch.tensor(data=True)
        model.set_is_batch_data(True)
        model_loaded = model

        self.cutoff = torch.tensor(model.cutoff)
        self.neighbor_list_fn = neighbor_list_fn
        self.model = model_loaded

        self.modal = None
        modal_map = self.model.modal_map
        if modal_map:
            modal_ava = list(modal_map.keys())
            if not modal:
                raise ValueError(f"modal argument missing (avail: {modal_ava})")
            if modal not in modal_ava:
                raise ValueError(f"unknown modal {modal} (not in {modal_ava})")
            self.modal = modal
        elif not self.model.modal_map and modal:
            warnings.warn(
                f"modal={modal} is ignored as model has no modal_map",
                stacklevel=2,
            )

        self.model = model.to(self._device)
        self.model = self.model.eval()
        if self.dtype is not None:
            self.model = self.model.to(dtype=self.dtype)

        self.implemented_properties = ["energy", "forces", "stress"]
[docs]
    def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
        """Perform forward pass to compute energies, forces, and other properties.

        Takes a simulation state and computes the properties implemented by the model,
        such as energy, forces, and stresses.

        Args:
            state (SimState | StateDict): State object containing positions, cells,
                atomic numbers, and other system information. If a dictionary is
                provided, it will be converted to a SimState.

        Returns:
            dict: Dictionary of model predictions, which may include:
                - energy (torch.Tensor): Energy with shape [batch_size]
                - forces (torch.Tensor): Forces with shape [n_atoms, 3]
                - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3],
                  if compute_stress is True

        Notes:
            The state is automatically transferred to the model's device if needed.
            All output tensors are detached from the computation graph.
        """
        if isinstance(state, dict):
            state = SimState(**state, masses=torch.ones_like(state["positions"]))

        if state.device != self._device:
            state = state.to(self._device)

        # TODO: is this clone necessary?
        state = state.clone()

        data_list = []
        for b in range(state.batch.max().item() + 1):
            batch_mask = state.batch == b

            pos = state.positions[batch_mask]
            # SevenNet uses row vector cell convention for neighbor list
            row_vector_cell = state.row_vector_cell[b]
            pbc = state.pbc
            atomic_numbers = state.atomic_numbers[batch_mask]

            edge_idx, shifts_idx = self.neighbor_list_fn(
                positions=pos,
                cell=row_vector_cell,
                pbc=pbc,
                cutoff=self.cutoff,
            )
            shifts = torch.mm(shifts_idx, row_vector_cell)
            edge_vec = pos[edge_idx[1]] - pos[edge_idx[0]] + shifts

            data = {
                key.NODE_FEATURE: atomic_numbers,
                key.ATOMIC_NUMBERS: atomic_numbers.to(
                    dtype=torch.int64, device=self.device
                ),
                key.POS: pos,
                key.EDGE_IDX: edge_idx,
                key.EDGE_VEC: edge_vec,
                key.CELL: row_vector_cell,
                key.CELL_SHIFT: shifts_idx,
                key.CELL_VOLUME: torch.det(row_vector_cell),
                key.NUM_ATOMS: torch.tensor(len(atomic_numbers), device=self.device),
                key.DATA_MODALITY: self.modal,
            }
            data[key.INFO] = {}

            data = AtomGraphData(**data)
            data_list.append(data)

        batched_data = Collater([], follow_batch=None, exclude_keys=None)(data_list)
        batched_data.to(self.device)

        if isinstance(self.model, torch_script_type):
            batched_data[key.NODE_FEATURE] = torch.tensor(
                [self.type_map[z.item()] for z in data[key.NODE_FEATURE]],
                dtype=torch.int64,
                device=self.device,
            )
            # backward compatibility
            batched_data[key.POS].requires_grad_(requires_grad=True)
            batched_data[key.EDGE_VEC].requires_grad_(requires_grad=True)
            batched_data = batched_data.to_dict()
            del batched_data["data_info"]

        output = self.model(batched_data)

        results = {}

        energy = output[key.PRED_TOTAL_ENERGY]
        if energy is not None:
            results["energy"] = energy.detach()
        else:
            results["energy"] = torch.zeros(
                state.batch.max().item() + 1, device=self.device
            )

        forces = output[key.PRED_FORCE]
        if forces is not None:
            results["forces"] = forces.detach()

        stress = output[key.PRED_STRESS]
        if stress is not None:
            results["stress"] = voigt_6_to_full_3x3_stress(stress.detach())

        return results
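
A minimal usage sketch tying the wrapper together, under stated assumptions: `model_from_checkpoint` is assumed to be sevenn's checkpoint-loading utility returning an AtomGraphSequential plus its config, `atoms_to_state` is assumed to be torch_sim's ASE converter, and the checkpoint filename is purely illustrative; none of these names are defined in this module.

    # Hypothetical usage sketch; external helper names and paths are assumptions.
    import torch
    from ase.build import bulk

    import torch_sim as ts
    from sevenn.util import model_from_checkpoint  # assumed sevenn loader utility
    from torch_sim.models.sevennet import SevenNetModel

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load a trained SevenNet checkpoint; assumed to return (AtomGraphSequential, config).
    loaded_model, _config = model_from_checkpoint("checkpoint_sevennet_0.pth")

    # Wrap the loaded model; for a multi-modal checkpoint (e.g. 7net-mf-ompa),
    # also pass modal="mpa" or modal="omat24".
    model = SevenNetModel(model=loaded_model, device=device)

    # Build a batched SimState from ASE atoms (atoms_to_state assumed to accept a list).
    atoms = bulk("Cu", "fcc", a=3.58, cubic=True)
    state = ts.io.atoms_to_state([atoms], device=device, dtype=torch.float32)

    results = model(state)
    print(results["energy"].shape)  # [batch_size]
    print(results["forces"].shape)  # [n_atoms, 3]
    print(results["stress"].shape)  # [batch_size, 3, 3]

Note that the modal argument is required only when the underlying model defines a modal_map; passing it to a single-modal model is ignored with a warning, as in the constructor above.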