OrbModel
- class torch_sim.models.orb.OrbModel(model, *, conservative=None, compute_stress=True, compute_forces=True, system_config=None, max_num_neighbors=None, edge_method=None, half_supercell=None, device=None, dtype=torch.float32)
Bases: Module, ModelInterface
Computes atomistic energies, forces and stresses using an ORB model.
This class wraps an ORB model to compute energies, forces, and stresses for atomistic systems. It handles model initialization and configuration, and provides a forward pass that accepts a SimState object and returns the model's predictions.
- Variables:
model (Union[GraphRegressor, ConservativeForcefieldRegressor]) – The ORB model
system_config (SystemConfig) – Configuration for the atomic system
conservative (bool) – Whether to use conservative forces/stresses calculation
implemented_properties (list) – Properties the model can compute
_dtype (dtype) – Data type used for computation
_device (device) – Device where computation is performed
_edge_method (EdgeCreationMethod) – Method for creating edges in the graph
_max_num_neighbors (int) – Maximum number of neighbors for each atom
_half_supercell (bool) – Whether to use half supercell optimization
_memory_scales_with (str) – What the memory usage scales with
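The `conservative` flag concerns whether forces are obtained from the energy itself. In a conservative force field, F = -∂E/∂r, so the forces are consistent with the energy surface by construction. The sketch below illustrates only that relation. It assumes a toy Lennard-Jones pair energy (`lj_energy`, a hypothetical stand-in for the ORB network) and uses central finite differences where a real model would use automatic differentiation; it is not ORB's or torch-sim's implementation.

```python
import math


def lj_energy(positions, epsilon=1.0, sigma=1.0):
    """Toy Lennard-Jones energy; a hypothetical stand-in for the ORB energy head."""
    energy = 0.0
    n = len(positions)
    for i in range(n):
        for j in range(i + 1, n):
            r = math.dist(positions[i], positions[j])
            sr6 = (sigma / r) ** 6
            energy += 4.0 * epsilon * (sr6 * sr6 - sr6)
    return energy


def conservative_forces(energy_fn, positions, h=1e-5):
    """Forces as F = -dE/dr, estimated with central finite differences."""
    forces = []
    for i in range(len(positions)):
        f_atom = []
        for k in range(3):
            # Displace atom i along axis k by +h and -h, leaving the rest fixed.
            plus = [list(p) for p in positions]
            minus = [list(p) for p in positions]
            plus[i][k] += h
            minus[i][k] -= h
            dE = (energy_fn(plus) - energy_fn(minus)) / (2.0 * h)
            f_atom.append(-dE)
        forces.append(f_atom)
    return forces


# Two atoms 1.5 sigma apart, i.e. in the attractive part of the potential.
positions = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]
forces = conservative_forces(lj_energy, positions)
```

The two forces come out equal and opposite, as expected for a pair potential. Stresses follow the same principle: they are derivatives of the energy with respect to a strain of the cell, normalized by the cell volume.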
- Parameters:
Examples
>>> from torch_sim.models.orb import OrbModel
>>> model = OrbModel(model=loaded_orb_model, compute_stress=True)
>>> results = model(state)
- forward(state)
Perform forward pass to compute energies, forces, and other properties.
Takes a simulation state and computes the properties implemented by the model, such as energy, forces, and stresses.
- Parameters:
state (SimState | StateDict) – State object containing positions, cells, atomic numbers, and other system information. If a dictionary is provided, it will be converted to a SimState.
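Because `forward` accepts either a state object or a plain dictionary, the coercion step can be pictured as below. This is a minimal sketch under stated assumptions: `StateSketch`, its three fields, and `as_state` are hypothetical stand-ins (plain Python lists instead of tensors), not torch-sim's actual `SimState` class or its conversion code.

```python
from dataclasses import dataclass


@dataclass
class StateSketch:
    """Hypothetical, simplified stand-in for a SimState (not the real class)."""

    positions: list  # [n_atoms][3] Cartesian coordinates
    cell: list  # [3][3] lattice vectors
    atomic_numbers: list  # [n_atoms] atomic numbers


def as_state(state):
    """Return `state` unchanged if it is already a state object; otherwise
    build one from a dictionary whose keys match the dataclass fields."""
    if isinstance(state, StateSketch):
        return state
    if isinstance(state, dict):
        return StateSketch(**state)
    raise TypeError(f"expected StateSketch or dict, got {type(state).__name__}")


# An H2 molecule in a 10 x 10 x 10 box, given as a plain dictionary.
raw = {
    "positions": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.74]],
    "cell": [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]],
    "atomic_numbers": [1, 1],
}
state = as_state(raw)
```

Unexpected dictionary keys make the dataclass constructor raise a `TypeError`, so malformed input fails early rather than inside the model.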
- Returns:
- Dictionary of model predictions, which may include:
energy (torch.Tensor): Energy with shape [batch_size]
forces (torch.Tensor): Forces with shape [n_atoms, 3]
stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3], if compute_stress is True
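The shapes differ because energies and stresses are reported per system, while forces are reported per atom and stacked across all systems in the batch. The sketch below shows one way to regroup the flat forces by system. It assumes a hypothetical per-atom `system_idx` list that maps each atom to its system; plain Python lists stand in for tensors, and the numbers are made up for illustration.

```python
def split_per_system(per_atom_values, system_idx, n_systems):
    """Group rows of a flat per-atom array by the system each atom belongs to."""
    if len(per_atom_values) != len(system_idx):
        raise ValueError("need exactly one system index per atom")
    groups = [[] for _ in range(n_systems)]
    for value, s in zip(per_atom_values, system_idx):
        groups[s].append(value)
    return groups


# Two systems: atoms 0-1 belong to system 0, atoms 2-4 to system 1.
system_idx = [0, 0, 1, 1, 1]
energy = [-3.2, -7.9]  # shape [batch_size] = [2] (illustrative values)
forces = [  # shape [n_atoms, 3] = [5, 3] (illustrative values)
    [0.1, 0.0, 0.0],
    [-0.1, 0.0, 0.0],
    [0.0, 0.2, 0.0],
    [0.0, -0.1, 0.0],
    [0.0, -0.1, 0.0],
]

forces_by_system = split_per_system(forces, system_idx, n_systems=len(energy))
for s, (e, f) in enumerate(zip(energy, forces_by_system)):
    print(f"system {s}: energy={e}, n_atoms={len(f)}")
```

With real tensors, the same grouping is typically done with boolean masks or, when each system's atoms are contiguous, with `torch.split` and a list of per-system atom counts.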
- Return type:
dict[str, torch.Tensor]
Notes
The state is automatically transferred to the model’s device if needed. All output tensors are detached from the computation graph.