MetatensorModel

class torch_sim.models.metatensor.MetatensorModel(model=None, extensions_path=None, device=None, *, check_consistency=False, compute_forces=True, compute_stress=True)

Bases: Module, ModelInterface

Computes energies (and, optionally, forces and stresses) for a batch of systems using a metatensor model.

This class wraps a metatensor model to compute energies, forces, and stresses for atomic systems within the TorchSim framework. It supports batched calculations for multiple systems and handles the necessary transformations between TorchSim’s data structures and metatensor’s expected inputs.
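A batched calculation stores the atoms of several systems in one flat collection, while a model that works system by system needs them grouped. The sketch below shows that round trip in plain Python. It assumes each atom carries the index of its owning system in a flat list (called `system_idx` here); `split_by_system` and `merge_by_system` are hypothetical helpers for illustration, not part of torch_sim or metatensor.

```python
from typing import Any, List, Sequence, Tuple


def split_by_system(
    values: Sequence[Any], system_idx: Sequence[int], n_systems: int
) -> Tuple[List[List[Any]], List[List[int]]]:
    """Group per-atom values by system (hypothetical helper).

    Returns the grouped values plus, for each system, the original flat
    indices of its atoms so results can be scattered back later.
    """
    if len(values) != len(system_idx):
        raise ValueError("values and system_idx must have the same length")
    groups: List[List[Any]] = [[] for _ in range(n_systems)]
    indices: List[List[int]] = [[] for _ in range(n_systems)]
    for i, (value, s) in enumerate(zip(values, system_idx)):
        groups[s].append(value)
        indices[s].append(i)
    return groups, indices


def merge_by_system(
    groups: Sequence[Sequence[Any]], indices: Sequence[Sequence[int]], n_atoms: int
) -> List[Any]:
    """Scatter per-system, per-atom results (e.g. forces) back into flat order."""
    out: List[Any] = [None] * n_atoms
    for group, idx in zip(groups, indices):
        for value, i in zip(group, idx):
            out[i] = value
    return out


# Two systems: atoms 0 and 1 belong to system 0, atom 2 to system 1.
positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 0.0, 0.0)]
groups, idx = split_by_system(positions, [0, 0, 1], n_systems=2)
restored = merge_by_system(groups, idx, n_atoms=len(positions))
```

Keeping the original indices lets the merge step work even when atoms of different systems are interleaved rather than stored contiguously.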

Variables:

...

Parameters:
forward(state)

Compute energies, forces, and stresses for the given atomic systems.

Processes the provided state and computes energies, forces, and stresses using the underlying metatensor model. Handles batched calculations for multiple systems and constructs the necessary neighbor lists.
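A neighbor list turns raw positions and a periodic cell into the atom pairs a model actually sees. As a rough illustration only, here is a brute-force O(N^2) half neighbor list for an orthorhombic box using the minimum-image convention. It is not the implementation used by TorchSim or metatensor (production codes typically use cell lists, support triclinic cells, and record periodic shift vectors), and `brute_force_neighbor_list` is a hypothetical name.

```python
import itertools
import math
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def brute_force_neighbor_list(
    positions: Sequence[Vec3], box: Vec3, cutoff: float
) -> List[Tuple[int, int, float]]:
    """Return (i, j, distance) for every pair i < j closer than `cutoff`.

    Minimum image is only valid when cutoff <= half the shortest box edge,
    so larger cutoffs are rejected rather than silently missing images.
    """
    if cutoff > 0.5 * min(box):
        raise ValueError("cutoff must not exceed half the shortest box length")
    pairs = []
    for i, j in itertools.combinations(range(len(positions)), 2):
        sq = 0.0
        for a in range(3):
            d = positions[j][a] - positions[i][a]
            d -= box[a] * round(d / box[a])  # wrap to the nearest periodic image
            sq += d * d
        dist = math.sqrt(sq)
        if dist < cutoff:
            pairs.append((i, j, dist))
    return pairs


# Atoms 0 and 1 are 9.0 apart inside the box but only 1.0 apart across the boundary.
pairs = brute_force_neighbor_list(
    [(0.5, 0.0, 0.0), (9.5, 0.0, 0.0), (5.0, 5.0, 5.0)],
    box=(10.0, 10.0, 10.0),
    cutoff=2.0,
)
```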

Parameters:

state (SimState | StateDict) – State object containing positions, cell, and other system information. Can be either a SimState object or a dictionary with the relevant fields.
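Accepting either a state object or a plain dictionary usually means normalizing the input to one form before use. The sketch below illustrates that pattern with a hypothetical `ToyState` dataclass standing in for SimState (only three fields are modeled, and the real SimState has more); `as_state` is likewise a hypothetical helper, not TorchSim's code.

```python
from dataclasses import dataclass
from typing import Any, List, Mapping, Union


@dataclass
class ToyState:
    """Hypothetical stand-in for SimState with only three fields."""

    positions: List[List[float]]
    cell: List[List[float]]
    pbc: bool


def as_state(state: Union[ToyState, Mapping[str, Any]]) -> ToyState:
    """Return `state` unchanged if it is already a ToyState, else build one from a mapping."""
    if isinstance(state, ToyState):
        return state
    if isinstance(state, Mapping):
        # Missing or unexpected keys raise TypeError from the dataclass constructor.
        return ToyState(**state)
    raise TypeError(f"expected ToyState or mapping, got {type(state).__name__}")


state_dict = {
    "positions": [[0.0, 0.0, 0.0]],
    "cell": [[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0]],
    "pbc": True,
}
state = as_state(state_dict)
```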

Returns:

Dictionary containing:
  • 'energy': System energies with shape [n_systems]

  • 'forces': Atomic forces with shape [n_atoms, 3] if compute_forces=True

  • 'stress': System stresses with shape [n_systems, 3, 3] if compute_stress=True

Return type:

dict[str, Tensor]
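To make the shape contract of the returned dictionary concrete, here is a small checker. It is a hypothetical helper written against plain nested lists so it runs without PyTorch; with real tensors you would compare `tensor.shape` directly. It assumes the lists are rectangular and only inspects the first element along each axis.

```python
from typing import Any, Dict, Tuple


def shape(x: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list; scalars have shape ()."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


def check_results(
    results: Dict[str, Any],
    n_systems: int,
    n_atoms: int,
    compute_forces: bool = True,
    compute_stress: bool = True,
) -> None:
    """Raise if `results` lacks an expected key or has a wrongly shaped entry."""
    expected = {"energy": (n_systems,)}
    if compute_forces:
        expected["forces"] = (n_atoms, 3)
    if compute_stress:
        expected["stress"] = (n_systems, 3, 3)
    missing = set(expected) - set(results)
    if missing:
        raise KeyError(f"missing keys: {sorted(missing)}")
    for key, want in expected.items():
        got = shape(results[key])
        if got != want:
            raise ValueError(f"{key}: expected shape {want}, got {got}")


# Two systems with three atoms in total.
results = {
    "energy": [-1.0, -2.0],
    "forces": [[0.0] * 3 for _ in range(3)],
    "stress": [[[0.0] * 3 for _ in range(3)] for _ in range(2)],
}
check_results(results, n_systems=2, n_atoms=3)
```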