OrbModel

class torch_sim.models.orb.OrbModel(model, *, conservative=None, compute_stress=True, compute_forces=True, system_config=None, max_num_neighbors=None, edge_method=None, half_supercell=None, device=None, dtype=torch.float32)

Bases: Module, ModelInterface

Computes atomistic energies, forces and stresses using an ORB model.

This class wraps an ORB model to compute energies, forces, and stresses for atomistic systems. It handles model initialization and configuration, and provides a forward pass that accepts a SimState object and returns the model's predictions.

Variables:
  • model (Union[GraphRegressor, ConservativeForcefieldRegressor]) – The ORB model

  • system_config (SystemConfig) – Configuration for the atomic system

  • conservative (bool) – Whether forces and stresses are computed as gradients of the predicted energy (conservative) instead of being predicted directly

  • implemented_properties (list) – Properties the model can compute

  • _dtype (dtype) – Data type used for computation

  • _device (device) – Device where computation is performed

  • _edge_method (EdgeCreationMethod) – Method for creating edges in the graph

  • _max_num_neighbors (int) – Maximum number of neighbors for each atom

  • _half_supercell (bool) – Whether to use half supercell optimization

  • _memory_scales_with (str) – Quantity that the model's memory usage scales with
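With conservative=True, forces and stresses are derivatives of the predicted energy rather than separate model outputs. The sketch below shows one common way to do this with PyTorch autograd, applied to a hypothetical toy energy (harmonic springs between atom pairs, no periodic images) rather than an ORB network: forces are F = -∂E/∂r, and stress is taken as σ = (1/V) ∂E/∂ε for a symmetric strain ε applied to the positions. The function names and the stress sign convention (the one ASE uses) are assumptions for illustration, not torch-sim internals.

```python
import torch


def toy_energy(positions: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a learned energy model (NOT ORB):
    # harmonic springs with rest length 1.0 between every pair of atoms.
    n = positions.shape[0]
    i, j = torch.triu_indices(n, n, offset=1)
    dist = (positions[i] - positions[j]).norm(dim=-1)
    return 0.5 * ((dist - 1.0) ** 2).sum()


def conservative_forces_and_stress(positions: torch.Tensor, cell: torch.Tensor):
    # Zero strain we differentiate against; symmetrizing it makes the
    # resulting stress tensor symmetric.
    strain = torch.zeros(3, 3, dtype=positions.dtype, requires_grad=True)
    sym_strain = 0.5 * (strain + strain.T)
    pos = positions.detach().clone().requires_grad_(True)
    # Row-vector positions are deformed as r' = r (I + eps).
    strained_pos = pos + pos @ sym_strain
    energy = toy_energy(strained_pos)
    dE_dpos, dE_dstrain = torch.autograd.grad(energy, (pos, strain))
    # The toy energy ignores periodic images, so the cell only sets the volume.
    volume = torch.linalg.det(cell).abs()
    forces = -dE_dpos             # F = -dE/dr
    stress = dE_dstrain / volume  # sigma = (1/V) dE/d(eps)
    return energy.detach(), forces, stress
```

For two atoms 1.5 apart along x, this gives forces of +0.5 and -0.5 along x, which sum to zero as expected for a translation-invariant energy.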

Parameters:
  • model (GraphRegressor | ConservativeForcefieldRegressor | str | Path)

  • conservative (bool | None)

  • compute_stress (bool)

  • compute_forces (bool)

  • system_config (SystemConfig | None)

  • max_num_neighbors (int | None)

  • edge_method (EdgeCreationMethod | None)

  • half_supercell (bool | None)

  • device (device | str | None)

  • dtype (dtype)

Examples

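>>> # loaded_orb_model: an ORB GraphRegressor or ConservativeForcefieldRegressor
>>> # state: a SimState (or a dict convertible to one)
>>> model = OrbModel(model=loaded_orb_model, compute_stress=True)
>>> results = model(state)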

forward(state)

Perform forward pass to compute energies, forces, and other properties.

Takes a simulation state and computes the properties implemented by the model, such as energy, forces, and stresses.

Parameters:

state (SimState | StateDict) – State object containing positions, cells, atomic numbers, and other system information. If a dictionary is provided, it will be converted to a SimState.

Returns:

Dictionary of model predictions, which may include:
  • energy (torch.Tensor): Energy with shape [batch_size]

  • forces (torch.Tensor): Forces with shape [n_atoms, 3]

  • stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3], if compute_stress is True

Return type:

dict
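Because energy and stress are indexed by system while forces are indexed by atom across the whole batch, per-system post-processing needs a map from atoms to systems. The sketch below assumes a hypothetical per-atom integer tensor `system_idx`, with each system's atoms stored contiguously; plain tensors stand in for results["energy"] and results["forces"], and the helper names are illustrative.

```python
import torch


def split_forces_by_system(forces: torch.Tensor, system_idx: torch.Tensor, n_systems: int):
    # Number of atoms in each system, e.g. tensor([3, 2]) for [0, 0, 0, 1, 1].
    counts = torch.bincount(system_idx, minlength=n_systems)
    # Assumes atoms of each system are contiguous in the [n_atoms, 3] tensor.
    return list(torch.split(forces, counts.tolist()))


def energy_per_atom(energy: torch.Tensor, system_idx: torch.Tensor) -> torch.Tensor:
    # energy has shape [batch_size]; divide each entry by its system's atom count.
    counts = torch.bincount(system_idx, minlength=energy.shape[0])
    return energy / counts.to(energy.dtype)


# Hypothetical batch: system 0 has 3 atoms, system 1 has 2 atoms.
system_idx = torch.tensor([0, 0, 0, 1, 1])
energy = torch.tensor([-6.0, -4.0])  # shape [batch_size]
forces = torch.zeros(5, 3)           # shape [n_atoms, 3]
per_system_forces = split_forces_by_system(forces, system_idx, n_systems=2)
per_atom_energy = energy_per_atom(energy, system_idx)
```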

Notes

The state is automatically transferred to the model’s device if needed. All output tensors are detached from the computation graph.
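In practice, this means the returned tensors can be used as ordinary values but cannot be backpropagated through. The sketch below mimics that contract with a hypothetical wrapper, `DetachedOutputs`, around a toy module `ToyEnergy`; it is not the torch-sim implementation, only an illustration of moving inputs to the wrapped model's device and dtype and detaching every output.

```python
import torch


class ToyEnergy(torch.nn.Module):
    # Hypothetical inner model with one trainable parameter, so its raw
    # outputs are attached to the autograd graph.
    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.tensor(2.0))

    def forward(self, positions: torch.Tensor) -> dict[str, torch.Tensor]:
        return {"energy": self.scale * (positions**2).sum()}


class DetachedOutputs(torch.nn.Module):
    # Hypothetical wrapper: moves inputs to the inner model's device/dtype
    # and returns every output detached from the computation graph.
    def __init__(self, inner: torch.nn.Module, device="cpu", dtype=torch.float32):
        super().__init__()
        self.inner = inner.to(device=device, dtype=dtype)
        self.device = torch.device(device)
        self.dtype = dtype

    def forward(self, positions: torch.Tensor) -> dict[str, torch.Tensor]:
        positions = positions.to(device=self.device, dtype=self.dtype)
        raw = self.inner(positions)
        return {name: value.detach() for name, value in raw.items()}
```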