MaceModel

class torch_sim.models.mace.MaceModel(model=None, *, device=None, dtype=torch.float64, neighbor_list_fn=vesin_nl_ts, compute_forces=True, compute_stress=True, enable_cueq=False, atomic_numbers=None, batch=None)

Bases: Module, ModelInterface

Computes energies for multiple systems using a MACE model.

This class wraps a MACE model to compute energies, forces, and stresses for atomic systems within the TorchSim framework. It supports batched calculations for multiple systems and handles the necessary transformations between TorchSim’s data structures and MACE’s expected inputs.

Variables:
  • r_max (float) – Cutoff radius for neighbor interactions.

  • z_table (utils.AtomicNumberTable) – Table mapping atomic numbers to indices.

  • model (Module) – The underlying MACE neural network model.

  • neighbor_list_fn (Callable) – Function used to compute neighbor lists.

  • atomic_numbers (Tensor) – Atomic numbers with shape [n_atoms].

  • batch (Tensor) – Batch indices with shape [n_atoms].

  • n_systems (int) – Number of systems in the batch.

  • n_atoms_per_system (list[int]) – Number of atoms in each system.

  • ptr (Tensor) – Pointers to the start of each system in the batch with shape [n_systems + 1].

  • total_atoms (int) – Total number of atoms across all systems.

  • node_attrs (Tensor) – One-hot encoded atomic types with shape [n_atoms, n_elements].
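Every bookkeeping attribute above can be derived from the batch tensor. The following plain-Python sketch uses lists instead of tensors and a hypothetical helper named batch_bookkeeping (not part of TorchSim). It assumes atoms are stored contiguously by system, so that ptr can be used for slicing:

```python
# Hypothetical pure-Python sketch of how the batch bookkeeping attributes relate.
# ptr only describes valid per-system slices if batch is sorted, i.e. all atoms
# of system 0 come first, then all atoms of system 1, and so on.
from itertools import accumulate


def batch_bookkeeping(batch: list[int]) -> dict:
    """Derive n_systems, n_atoms_per_system, ptr, and total_atoms from batch indices."""
    n_systems = max(batch) + 1 if batch else 0
    # Count how many atoms carry each system index.
    n_atoms_per_system = [batch.count(i) for i in range(n_systems)]
    # ptr[i] is the index of the first atom of system i; ptr[-1] equals total_atoms.
    ptr = [0, *accumulate(n_atoms_per_system)]
    total_atoms = len(batch)
    return {
        "n_systems": n_systems,
        "n_atoms_per_system": n_atoms_per_system,
        "ptr": ptr,
        "total_atoms": total_atoms,
    }


# Two systems: a 3-atom system followed by a 2-atom system.
info = batch_bookkeeping([0, 0, 0, 1, 1])
print(info)
# {'n_systems': 2, 'n_atoms_per_system': [3, 2], 'ptr': [0, 3, 5], 'total_atoms': 5}
```

With this layout, the atoms of system i occupy indices ptr[i] through ptr[i + 1] - 1. This is why ptr has n_systems + 1 entries.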

Parameters:
setup_from_batch(atomic_numbers, batch)

Set up internal state from atomic numbers and batch indices.

Stores the atomic numbers and batch indices, then derives the batch bookkeeping attributes (n_systems, n_atoms_per_system, ptr, total_atoms) and the one-hot node_attrs that the forward pass uses to process multiple systems in a single batch.

Parameters:
  • atomic_numbers (Tensor) – Atomic numbers tensor with shape [n_atoms].

  • batch (Tensor) – Batch indices tensor with shape [n_atoms] indicating which system each atom belongs to.

Return type:

None
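The node_attrs attribute is a one-hot encoding. Each atom becomes a row containing a single 1.0, placed in the column that z_table assigns to the atom's element. The sketch below illustrates that mapping in plain Python. It assumes z_table can be treated as a sorted list of the atomic numbers the model supports. The function one_hot_node_attrs is hypothetical and belongs to neither TorchSim nor MACE.

```python
# Hypothetical sketch: build one-hot node attributes from atomic numbers.
# The element table here is a plain list standing in for MACE's
# utils.AtomicNumberTable; the real class is not used.


def one_hot_node_attrs(atomic_numbers: list[int], z_table: list[int]) -> list[list[float]]:
    """Return an [n_atoms, n_elements] one-hot matrix as nested lists."""
    # Map each supported atomic number to its column index.
    index_of = {z: i for i, z in enumerate(z_table)}
    rows = []
    for z in atomic_numbers:
        if z not in index_of:
            raise ValueError(f"atomic number {z} is not in the element table")
        row = [0.0] * len(z_table)
        row[index_of[z]] = 1.0
        rows.append(row)
    return rows


# Element table for H, C, O; a water molecule ordered O, H, H.
attrs = one_hot_node_attrs([8, 1, 1], z_table=[1, 6, 8])
for row in attrs:
    print(row)
# [0.0, 0.0, 1.0]
# [1.0, 0.0, 0.0]
# [1.0, 0.0, 0.0]
```

Rejecting elements missing from the table is a choice made in this sketch. It says nothing about how the real table handles unsupported elements.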

forward(state)

Compute energies, forces, and stresses for the given atomic systems.

Processes the provided state information and computes energies, forces, and stresses using the underlying MACE model. Handles batched calculations for multiple systems and constructs the necessary neighbor lists.
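A neighbor list pairs each atom with the atoms of the same system that lie within the cutoff radius r_max. The brute-force sketch below only illustrates that contract. It is O(n²), it ignores the cell and periodic boundary conditions, and it is a hypothetical stand-in for neighbor_list_fn rather than a reimplementation of the default vesin_nl_ts. It returns both directions of each pair, which is a common convention for message-passing edge lists.

```python
# Hypothetical brute-force neighbor list standing in for neighbor_list_fn.
# Ignores periodic boundary conditions, so it only matches a real neighbor
# list for non-periodic systems.
import math


def naive_neighbor_list(positions, batch, r_max):
    """Return directed (i, j) pairs with distance < r_max inside the same system."""
    pairs = []
    n = len(positions)
    for i in range(n):
        for j in range(n):
            # Skip self-pairs and atoms that belong to different systems.
            if i == j or batch[i] != batch[j]:
                continue
            if math.dist(positions[i], positions[j]) < r_max:
                pairs.append((i, j))
    return pairs


positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (5.0, 0.0, 0.0),  # system 0
             (0.0, 0.0, 0.0), (0.5, 0.0, 0.0)]                  # system 1
batch = [0, 0, 0, 1, 1]
print(naive_neighbor_list(positions, batch, r_max=2.0))
# [(0, 1), (1, 0), (3, 4), (4, 3)]
```

Atoms 0 and 3 share the same coordinates. They are still not paired, because the batch indices keep the two systems separate.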

Parameters:

state (SimState | StateDict) – State object containing positions, cell, and other system information. Can be either a SimState object or a dictionary with the relevant fields.

Returns:

Dictionary containing:
  • 'energy': System energies with shape [n_systems]

  • 'forces': Atomic forces with shape [n_atoms, 3] if compute_forces=True

  • 'stress': System stresses with shape [n_systems, 3, 3] if compute_stress=True

Return type:

dict[str, Tensor]
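The 'energy' entry holds one value per system, even though the model operates on atoms. The usual MACE formulation writes a system's energy as the sum of per-atom contributions. Assuming that formulation, the reduction is a scatter sum over batch indices. The sketch below uses made-up per-atom values and a hypothetical helper, per_system_energy:

```python
# Hypothetical sketch: reduce per-atom energies to per-system energies
# using batch indices (a "scatter sum"). The per-atom values are made up.


def per_system_energy(atom_energies: list[float], batch: list[int], n_systems: int) -> list[float]:
    """Sum atomic energy contributions into one total per system."""
    totals = [0.0] * n_systems
    for energy, system in zip(atom_energies, batch):
        totals[system] += energy
    return totals


energies = per_system_energy([-1.5, -0.5, -0.5, -2.0, -2.0], [0, 0, 0, 1, 1], n_systems=2)
print(energies)  # [-2.5, -4.0]
```

The result has length n_systems, which matches the documented shape of 'energy'. With tensors, the same reduction can be written as torch.zeros(n_systems, dtype=atom_energies.dtype).index_add_(0, batch, atom_energies), where batch is an integer index tensor.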

Raises:
  • ValueError – If atomic numbers are not provided either in the constructor or in the forward pass, or if provided in both places.

  • ValueError – If batch indices are not provided when needed.
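Together, the two ValueError conditions on atomic numbers require that they come from exactly one source. Below is a minimal sketch of that check. It uses a hypothetical helper, resolve_atomic_numbers, which is not part of the TorchSim API.

```python
# Hypothetical sketch of the "exactly one source" rule for atomic numbers.
# The function name and signature are illustrative, not part of TorchSim.


def resolve_atomic_numbers(from_constructor, from_state):
    """Return the single provided atomic-number sequence or raise ValueError."""
    if from_constructor is None and from_state is None:
        raise ValueError("atomic numbers must be given in the constructor or the forward state")
    if from_constructor is not None and from_state is not None:
        raise ValueError("atomic numbers were given in both the constructor and the forward state")
    return from_constructor if from_constructor is not None else from_state


print(resolve_atomic_numbers(None, [8, 1, 1]))  # [8, 1, 1]
```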