GraphPESWrapper

class torch_sim.models.graphpes.GraphPESWrapper(model, device=None, dtype=torch.float64, *, compute_forces=True, compute_stress=True)

Bases: Module, ModelInterface

Wrapper for GraphPESModel in TorchSim.

This class provides a TorchSim wrapper around GraphPESModel instances, allowing them to be used within the broader torch_sim simulation framework.

The graph-pes package supports training existing model architectures, including SchNet, PaiNN, MACE, NequIP, TensorNet, EDDP and more. You can use any of these, as well as your own custom architectures, with this wrapper. See the graph-pes repo for more details: https://github.com/jla-gardner/graph-pes

Parameters:
  • model (GraphPESModel | str | Path) – GraphPESModel instance, or a path to a model file

  • device (device | None) – Device to run the model on

  • dtype (dtype) – Data type for the model

  • compute_forces (bool) – Whether to compute forces

  • compute_stress (bool) – Whether to compute stress
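The compute_forces and compute_stress flags decide which energy derivatives are returned alongside the energy. As a rough, self-contained illustration of what those quantities are, the sketch below uses PyTorch autograd to take the forces as the negative gradient of the energy with respect to positions, and the stress as the energy derivative with respect to a small homogeneous strain divided by the cell volume. This is not graph-pes or TorchSim code: the toy harmonic pair energy and the helpers `toy_energy` and `energy_forces_stress` are hypothetical, periodic images are ignored, and the cell is used only for its volume.

```python
import torch


def toy_energy(positions: torch.Tensor) -> torch.Tensor:
    # Hypothetical harmonic pair energy: sum over pairs i < j of (r_ij - 1)^2
    i, j = torch.triu_indices(len(positions), len(positions), offset=1)
    distances = (positions[i] - positions[j]).norm(dim=-1)
    return ((distances - 1.0) ** 2).sum()


def energy_forces_stress(positions: torch.Tensor, cell: torch.Tensor):
    # Track gradients with respect to positions and a zero homogeneous strain
    positions = positions.detach().requires_grad_(True)
    strain = torch.zeros(3, 3, dtype=positions.dtype, requires_grad=True)
    symmetric_strain = 0.5 * (strain + strain.T)
    identity = torch.eye(3, dtype=positions.dtype)
    # Deform the row-vector positions by (I + strain); at zero strain this is a no-op
    deformed = positions @ (identity + symmetric_strain)
    energy = toy_energy(deformed)
    grad_positions, grad_strain = torch.autograd.grad(energy, (positions, strain))
    forces = -grad_positions  # F = -dE/dR
    volume = torch.linalg.det(cell).abs()
    stress = grad_strain / volume  # one common convention: sigma = (1/V) dE/d(strain)
    return energy.detach(), forces, stress


positions = torch.tensor([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]], dtype=torch.float64)
cell = 10.0 * torch.eye(3, dtype=torch.float64)
energy, forces, stress = energy_forces_stress(positions, cell)
print(energy.item())  # 1.0
print(forces[:, 0].tolist())  # [2.0, -2.0]
print(stress[0, 0].item())  # approximately 0.004
```

With two atoms 2 units apart the toy energy is 1, the forces are equal and opposite along x (pulling the stretched pair together), and only the xx stress component is non-zero.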

Example

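>>> import torch
>>> from graph_pes.models import load_model
>>> from torch_sim.models.graphpes import GraphPESWrapper
>>> from torch_sim.state import SimState
>>> model = load_model("path/to/model.pt")
>>> wrapper = GraphPESWrapper(model)  # default dtype is torch.float64
>>> n_atoms = 10
>>> state = SimState(
...     positions=10.0 * torch.rand(n_atoms, 3, dtype=torch.float64),
...     masses=torch.full((n_atoms,), 12.011, dtype=torch.float64),
...     cell=10.0 * torch.eye(3, dtype=torch.float64).unsqueeze(0),  # (n_batches, 3, 3)
...     pbc=True,
...     atomic_numbers=torch.full((n_atoms,), 6),  # use elements the model was trained on
... )
>>> results = wrapper(state)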

forward(state)

Forward pass for the GraphPESWrapper.

Parameters:

state (SimState | dict[Literal['positions', 'masses', 'cell', 'pbc', 'atomic_numbers', 'batch'], Tensor]) – SimState object, or an equivalent dict of tensors with the keys above, containing the atomic positions, cell, and atomic numbers

Returns:

Dictionary containing the computed energies, plus forces and stresses when compute_forces and compute_stress are enabled

Return type:

dict[str, Tensor]
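forward also accepts a plain dict keyed by 'positions', 'masses', 'cell', 'pbc', 'atomic_numbers' and 'batch'. Below is a minimal sketch of assembling such a dict for several structures, assuming TorchSim's flat batching convention: all atoms concatenated, one cell per structure, and a per-atom `batch` index recording which structure each atom belongs to. The helpers `make_batched_inputs`, `per_structure_sum` and `toy_structure`, the exact shapes, and the scalar `pbc` tensor are assumptions for illustration, not documented API. The second helper shows the role of the `batch` index: reducing per-atom values to one value per structure with `index_add_`.

```python
import torch


def make_batched_inputs(structures: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    # Hypothetical helper: concatenate per-structure tensors into one flat batch.
    # Assumed shapes: positions (n_atoms, 3), masses (n_atoms,),
    # cell (n_structures, 3, 3), atomic_numbers and batch (n_atoms,).
    batch = torch.cat(
        [
            torch.full((len(s["atomic_numbers"]),), k, dtype=torch.long)
            for k, s in enumerate(structures)
        ]
    )
    return {
        "positions": torch.cat([s["positions"] for s in structures]),
        "masses": torch.cat([s["masses"] for s in structures]),
        "cell": torch.stack([s["cell"] for s in structures]),
        "pbc": torch.tensor(True),  # assumed: one flag shared by all structures
        "atomic_numbers": torch.cat([s["atomic_numbers"] for s in structures]),
        "batch": batch,
    }


def per_structure_sum(per_atom: torch.Tensor, batch: torch.Tensor, n_structures: int) -> torch.Tensor:
    # Add each atom's value into the slot of the structure it belongs to
    out = torch.zeros(n_structures, dtype=per_atom.dtype)
    return out.index_add_(0, batch, per_atom)


def toy_structure(n_atoms: int) -> dict[str, torch.Tensor]:
    # Hypothetical structure: n_atoms hydrogen atoms at random positions in a 5 Å cubic box
    return {
        "positions": 5.0 * torch.rand(n_atoms, 3, dtype=torch.float64),
        "masses": torch.full((n_atoms,), 1.008, dtype=torch.float64),
        "cell": 5.0 * torch.eye(3, dtype=torch.float64),
        "atomic_numbers": torch.full((n_atoms,), 1),
    }


inputs = make_batched_inputs([toy_structure(3), toy_structure(5)])
print(inputs["batch"].tolist())  # [0, 0, 0, 1, 1, 1, 1, 1]
print(per_structure_sum(torch.ones(8, dtype=torch.float64), inputs["batch"], 2).tolist())  # [3.0, 5.0]
```

Keeping atoms in one flat tensor with an index, rather than padding each structure to a common size, lets structures of different sizes share a single forward pass.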