GraphPESWrapper
- class torch_sim.models.graphpes.GraphPESWrapper(model, device=None, dtype=torch.float64, *, compute_forces=True, compute_stress=True)
Bases: Module, ModelInterface
Wrapper for GraphPESModel in TorchSim.
This class provides a TorchSim wrapper around GraphPESModel instances, allowing them to be used within the broader torch_sim simulation framework.
The graph-pes package supports training of existing model architectures, including SchNet, PaiNN, MACE, NequIP, TensorNet, EDDP, and more. Any of these, as well as your own custom architectures, can be used with this wrapper. See the graph-pes repository for more details: https://github.com/jla-gardner/graph-pes
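GraphPES models predict a potential energy, and the compute_forces and compute_stress flags in the constructor suggest that forces and stresses are optional derivative outputs. As a conceptual illustration only (not the wrapper's or graph-pes's actual code), the sketch below derives forces as F = -dE/dR and stress as sigma = (1/V) dE/d(eps) with `torch.autograd`, using one common sign convention for stress. The toy pair energy and the `energy_forces_stress` function are hypothetical stand-ins for a trained model.

```python
import torch


def toy_pair_energy(positions: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a learned model: sum over unique pairs of
    (|r_ij|^2 - 1)^2. It is rotation-invariant and ignores periodic images."""
    n = positions.shape[0]
    i, j = torch.triu_indices(n, n, offset=1)
    r2 = ((positions[i] - positions[j]) ** 2).sum(dim=-1)
    return ((r2 - 1.0) ** 2).sum()


def energy_forces_stress(
    positions: torch.Tensor,
    cell: torch.Tensor,
    compute_forces: bool = True,
    compute_stress: bool = True,
) -> dict[str, torch.Tensor]:
    dtype = positions.dtype
    positions = positions.detach().requires_grad_(compute_forces)
    # Symmetric strain evaluated at zero: E(eps) is the energy of the deformed system.
    eps = torch.zeros(3, 3, dtype=dtype, requires_grad=compute_stress)
    deform = torch.eye(3, dtype=dtype) + 0.5 * (eps + eps.T)
    # Row-vector convention: every position is multiplied by (I + eps).
    # A periodic model would also receive the strained cell, cell @ deform.
    energy = toy_pair_energy(positions @ deform)

    results = {"energy": energy.detach()}
    wanted = [t for t, flag in ((positions, compute_forces), (eps, compute_stress)) if flag]
    grads = list(torch.autograd.grad(energy, wanted)) if wanted else []
    if compute_forces:
        results["forces"] = -grads.pop(0)  # F = -dE/dR
    if compute_stress:
        volume = torch.linalg.det(cell).abs()
        results["stress"] = grads.pop(0) / volume  # sigma = (1/V) dE/d(eps)
    return results


positions = torch.tensor([[0.0, 0.0, 0.0], [1.2, 0.0, 0.0], [0.0, 0.9, 0.0]], dtype=torch.float64)
cell = 5.0 * torch.eye(3, dtype=torch.float64)
out = energy_forces_stress(positions, cell)
print(sorted(out))  # ['energy', 'forces', 'stress']
```

Skipping a flag skips the corresponding gradient, which is why making forces and stresses optional can save work when only energies are needed.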
- Parameters:
Example
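>>> import torch
>>> from torch_sim.state import SimState
>>> from torch_sim.models.graphpes import GraphPESWrapper
>>> from graph_pes.models import load_model
>>> model = load_model("path/to/model.pt")
>>> wrapper = GraphPESWrapper(model)
>>> state = SimState(
...     positions=torch.randn(10, 3, dtype=torch.float64),
...     masses=torch.ones(10, dtype=torch.float64),
...     cell=torch.eye(3, dtype=torch.float64),
...     pbc=True,
...     atomic_numbers=torch.randint(1, 104, (10,)),
... )
>>> wrapper(state)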
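Besides a SimState, forward() is documented to accept a plain dict keyed by positions, masses, cell, pbc, atomic_numbers, and batch. The sketch below batches two systems into such a dict. The helper `make_state_dict` is hypothetical, and the tensor shapes (per-atom positions, one 3 x 3 cell per system, a per-atom batch index) are assumptions based on SimState conventions rather than something stated on this page.

```python
import torch

# Keys come from the forward() signature; shapes are assumed from SimState conventions.
System = tuple[torch.Tensor, torch.Tensor, torch.Tensor]  # (positions, cell, atomic_numbers)


def make_state_dict(
    systems: list[System], pbc: bool = True, dtype: torch.dtype = torch.float64
) -> dict[str, torch.Tensor]:
    """Hypothetical helper: batch several systems into one dict of tensors."""
    positions = torch.cat([pos.to(dtype) for pos, _, _ in systems])  # (n_atoms, 3)
    cell = torch.stack([c.to(dtype) for _, c, _ in systems])  # (n_systems, 3, 3)
    atomic_numbers = torch.cat([z.long() for _, _, z in systems])  # (n_atoms,)
    # batch[i] is the index of the system that atom i belongs to.
    batch = torch.cat(
        [torch.full((len(z),), k, dtype=torch.long) for k, (_, _, z) in enumerate(systems)]
    )
    # Placeholder masses of 1.0; a real workflow would use per-element atomic masses.
    masses = torch.ones(len(atomic_numbers), dtype=dtype)
    return {
        "positions": positions,
        "masses": masses,
        "cell": cell,
        "pbc": torch.tensor(pbc),
        "atomic_numbers": atomic_numbers,
        "batch": batch,
    }


water = (
    torch.tensor([[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]]),
    10.0 * torch.eye(3),
    torch.tensor([8, 1, 1]),
)
h2 = (torch.tensor([[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]]), 8.0 * torch.eye(3), torch.tensor([1, 1]))
state = make_state_dict([water, h2])
print(state["batch"])  # tensor([0, 0, 0, 1, 1])
```

Assuming forward() reads these keys as its signature suggests, such a dict could be passed as `wrapper(state)` in place of a SimState. Whether the wrapper expects pbc as a tensor or a plain bool is an assumption here.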
- forward(state)
Forward pass for the GraphPESWrapper.
- Parameters:
state (SimState | dict[Literal['positions', 'masses', 'cell', 'pbc', 'atomic_numbers', 'batch'], torch.Tensor]) – SimState object, or an equivalent dict of tensors, containing the atomic positions, masses, cell, periodic boundary conditions, atomic numbers, and batch indices
- Returns:
Dictionary containing the computed energies, plus forces and stresses when compute_forces and compute_stress are enabled
- Return type: