SevenNetModel
- class torch_sim.models.sevennet.SevenNetModel(model, *, modal=None, neighbor_list_fn=vesin_nl_ts, device=None, dtype=torch.float32)
Bases: Module, ModelInterface
Computes atomistic energies, forces, and stresses using a SevenNet model.
This class wraps a SevenNet model to compute energies, forces, and stresses for atomistic systems. It handles model initialization and configuration, and provides a forward pass that accepts a SimState object and returns the model's predictions.
Examples
>>> model = SevenNetModel(model=loaded_sevenn_model)
>>> results = model(state)
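- Parameters:
model – The SevenNet model to wrap.
modal – Modal (fidelity) to select when the wrapped model is multi-modal. Defaults to None.
neighbor_list_fn – Function used to build the neighbor list passed to the model. Defaults to vesin_nl_ts.
device – Device on which to run the model. Defaults to None.
dtype – Floating-point dtype used for computation. Defaults to torch.float32.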
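The neighbor_list_fn parameter supplies the pairs of atoms that lie within the model's interaction cutoff. The following is a hypothetical brute-force sketch of what such a function computes, assuming plain (x, y, z) tuples, no periodic boundary conditions, and O(N^2) cost; dedicated implementations such as the default vesin_nl_ts are far more efficient, and this return format is illustrative only, not the one torch_sim expects.

```python
import itertools
import math


def brute_force_neighbor_list(positions, cutoff):
    """Return (i, j, distance) for every ordered pair of distinct atoms closer than cutoff.

    Hypothetical illustration: ignores periodic cells and checks all N * (N - 1) pairs.
    """
    pairs = []
    # permutations yields both (i, j) and (j, i), i.e. directed edges
    for i, j in itertools.permutations(range(len(positions)), 2):
        distance = math.dist(positions[i], positions[j])
        if distance < cutoff:
            pairs.append((i, j, distance))
    return pairs


positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0)]
print(brute_force_neighbor_list(positions, cutoff=1.5))
```

Both directions of each pair are kept because message-passing models commonly treat neighbors as directed edges.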
- forward(state)
Perform a forward pass to compute energies, forces, and other properties.
Takes a simulation state and computes the properties implemented by the model, such as energy, forces, and stresses.
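In conservative potentials such as SevenNet, forces are the negative gradient of the predicted energy with respect to atomic positions, and such models typically obtain them by automatic differentiation. The sketch below only illustrates that relationship: it swaps in a toy Lennard-Jones energy (a hypothetical stand-in for the model) and approximates F = -dE/dx with central finite differences.

```python
import math


def pair_energy(positions, epsilon=1.0, sigma=1.0):
    """Total Lennard-Jones energy of a small cluster (toy stand-in for a learned model)."""
    energy = 0.0
    n = len(positions)
    for i in range(n):
        for j in range(i + 1, n):
            r = math.dist(positions[i], positions[j])
            sr6 = (sigma / r) ** 6
            energy += 4.0 * epsilon * (sr6 * sr6 - sr6)
    return energy


def finite_difference_forces(energy_fn, positions, h=1e-6):
    """Approximate forces F = -dE/dx by central differences, one coordinate at a time."""
    forces = []
    for atom in range(len(positions)):
        atom_force = []
        for axis in range(3):
            plus = [list(p) for p in positions]
            minus = [list(p) for p in positions]
            plus[atom][axis] += h
            minus[atom][axis] -= h
            atom_force.append(-(energy_fn(plus) - energy_fn(minus)) / (2.0 * h))
        forces.append(atom_force)
    return forces


# Two atoms at r = sigma: the pair repels with magnitude 24 * epsilon / sigma
forces = finite_difference_forces(pair_energy, [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)])
print(forces)
```

The two force vectors come out equal and opposite, as expected for a pair interaction.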
- Parameters:
state (SimState | StateDict) – State object containing positions, cells, atomic numbers, and other system information. If a dictionary is provided, it will be converted to a SimState.
- Returns:
Dictionary of model predictions, which may include:
- energy (torch.Tensor): Energy with shape [batch_size]
- forces (torch.Tensor): Forces with shape [n_atoms, 3]
- stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3], if compute_stress is True
- Return type:
dict[str, torch.Tensor]
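The output shapes differ because energy is reported per system while forces are reported per atom. Assuming, as in NequIP-style models, that each system's energy is the sum of per-atom contributions, the batch_size energies can be obtained by scatter-adding atomic energies according to which system each atom belongs to. This is a minimal pure-Python sketch; the atom-to-system index list (system_idx here) is a hypothetical name used for illustration.

```python
def sum_per_system(atomic_energies, system_idx, n_systems):
    """Scatter-add per-atom energies into one total per system.

    system_idx[a] is the (hypothetical) index of the system that atom a belongs to.
    """
    totals = [0.0] * n_systems
    for energy, system in zip(atomic_energies, system_idx):
        totals[system] += energy
    return totals


# Five atoms in a batch of two systems: atoms 0-1 form system 0, atoms 2-4 form system 1
atomic_energies = [-1.0, -2.0, -0.5, -0.5, -1.0]
system_idx = [0, 0, 1, 1, 1]
print(sum_per_system(atomic_energies, system_idx, n_systems=2))  # [-3.0, -2.0]
```

The result has one entry per system (shape [batch_size]), whereas forces keep one row per atom (shape [n_atoms, 3]). With tensors, the same scatter-add is typically written with torch.Tensor.index_add_.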
Notes
The state is automatically transferred to the model’s device if needed. All output tensors are detached from the computation graph.