"""ORB: PyTorch implementation of ORB models for atomistic simulations.
This module provides a TorchSim wrapper around the ORB models library for computing
energies, forces, and stresses of atomistic systems, integrating it with the
torch_sim framework to enable seamless simulation with machine learning potentials.
The OrbModel class adapts ORB models to the ModelInterface protocol,
allowing them to be used within the broader torch_sim simulation framework.
Notes:
This implementation requires orb_models to be installed and accessible.
It supports various model configurations through model instances or model paths.
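Example:
A minimal usage sketch (assuming the installed orb_models version provides the
pretrained.orb_v2 loader and that state is an existing SimState):
>>> from orb_models.forcefield import pretrained
>>> orb = pretrained.orb_v2()  # any pretrained ORB regressor works here
>>> model = OrbModel(model=orb, compute_stress=True)
>>> results = model(state)  # dict with "energy", "forces", ... tensors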
"""
from __future__ import annotations
import typing
from pathlib import Path
import torch
from torch_sim.elastic import voigt_6_to_full_3x3_stress
from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState
try:
from ase.geometry import cell_to_cellpar
from orb_models.forcefield import featurization_utilities as feat_util
from orb_models.forcefield.atomic_system import SystemConfig
from orb_models.forcefield.base import AtomGraphs, _map_concat
from orb_models.forcefield.graph_regressor import GraphRegressor
try:
from orb_models.forcefield.featurization_utilities import EdgeCreationMethod
except ImportError as exp:
raise ImportError(
"Orb model version is too old, interface requires >v0.4.2. If release is "
"not yet available, install from github."
) from exp
except ImportError:
class OrbModel(torch.nn.Module, ModelInterface):
"""ORB model wrapper for torch_sim.
This placeholder stands in for the real OrbModel and raises an ImportError on
instantiation when orb_models is not installed.
"""
def __init__(self, *args, **kwargs) -> None: # noqa: ARG002
"""Dummy constructor to raise ImportError."""
raise ImportError("orb_models must be installed to use this model.")
if typing.TYPE_CHECKING:
from orb_models.forcefield.conservative_regressor import (
ConservativeForcefieldRegressor,
)
from orb_models.forcefield.featurization_utilities import EdgeCreationMethod
from orb_models.forcefield.graph_regressor import GraphRegressor
from torch_sim.typing import StateDict
def state_to_atom_graphs( # noqa: PLR0915
state: SimState,
*,
wrap: bool = True,
edge_method: EdgeCreationMethod | None = None,
system_config: SystemConfig | None = None,
max_num_neighbors: int | None = None,
system_id: int | None = None, # noqa: ARG001
half_supercell: bool = False,
device: torch.device | None = None,
output_dtype: torch.dtype | None = None,
graph_construction_dtype: torch.dtype | None = None,
) -> AtomGraphs:
"""Convert a SimState object into AtomGraphs format, ready for use in an ORB model.
Args:
state: SimState object containing atomic positions, cell, and atomic numbers
wrap: Whether to wrap atomic positions into the central unit cell (if there is
one).
edge_method (EdgeCreationMethod, optional): The method to use for graph edge
construction. If None, the edge method is chosen automatically based on
device and system size.
system_config: The system configuration to use for graph construction.
max_num_neighbors: Maximum number of neighbors each node can send messages to.
If None, will use system_config.max_num_neighbors.
system_id: Optional index of the system relative to a particular dataset
(accepted for API compatibility but unused here).
half_supercell (bool): Whether to use half the supercell for graph construction.
This can improve performance for large systems.
device: The device to put the tensors on.
output_dtype: The dtype to use for all floating point tensors stored on the
AtomGraphs object.
graph_construction_dtype: The dtype to use for floating point tensors in the
graph construction.
Returns:
AtomGraphs object containing the graph representation of the atomic system
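Example:
A minimal sketch (state is an existing SimState and orb is a loaded ORB
regressor; both names are illustrative):
>>> graphs = state_to_atom_graphs(state, device=torch.device("cpu"))
>>> predictions = orb.predict(graphs)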
"""
if system_config is None:
system_config = SystemConfig(radius=6.0, max_num_neighbors=20)
# Handle batch information if present
n_node = torch.bincount(state.batch)
# Set default dtype if not provided
output_dtype = torch.get_default_dtype() if output_dtype is None else output_dtype
graph_construction_dtype = (
torch.get_default_dtype()
if graph_construction_dtype is None
else graph_construction_dtype
)
# Extract data from SimState
positions = state.positions
row_vector_cell = (
state.row_vector_cell
) # Orb uses row vector cell convention for neighbor list
atomic_numbers = state.atomic_numbers.long()
# Create PBC tensor based on state.pbc
pbc = torch.tensor([state.pbc, state.pbc, state.pbc], dtype=torch.bool)
max_num_neighbors = max_num_neighbors or system_config.max_num_neighbors
# Get atom embeddings for the model
n_atoms = len(atomic_numbers)
k_hot = (
system_config.diffuse_atom_types
if hasattr(system_config, "diffuse_atom_types")
else False
)
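# k-hot embeddings fill every atom-type channel with -ATOM_TYPE_K and set the
# channel of the true atomic number to +ATOM_TYPE_K; otherwise a standard
# one-hot encoding over 118 elements is used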
if k_hot:
atom_type_embedding = torch.ones(n_atoms, 118) * -feat_util.ATOM_TYPE_K
atom_type_embedding[torch.arange(n_atoms), atomic_numbers] = feat_util.ATOM_TYPE_K
else:
atom_type_embedding = torch.nn.functional.one_hot(atomic_numbers, num_classes=118)
atomic_numbers_embedding = atom_type_embedding.to(output_dtype)
# Wrap positions into the central cell if needed
if wrap and (torch.any(row_vector_cell != 0) and torch.any(pbc)):
positions = feat_util.batch_map_to_pbc_cell(positions, row_vector_cell, n_node)
n_systems = state.batch.max().item() + 1
# Prepare lists to collect data from each system
all_edges = []
all_vectors = []
all_unit_shifts = []
num_edges = []
node_feats_list = []
edge_feats_list = []
graph_feats_list = []
# Process each system in a single loop
offset = 0
for i in range(n_systems):
batch_mask = state.batch == i
positions_per_system = positions[batch_mask]
atomic_numbers_per_system = atomic_numbers[batch_mask]
atomic_numbers_embedding_per_system = atomic_numbers_embedding[batch_mask]
cell_per_system = row_vector_cell[i]
pbc_per_system = pbc
# Compute edges directly for this system
edges, vectors, unit_shifts = feat_util.compute_pbc_radius_graph(
positions=positions_per_system,
cell=cell_per_system,
pbc=pbc_per_system,
radius=system_config.radius,
max_number_neighbors=max_num_neighbors,
edge_method=edge_method,
half_supercell=half_supercell,
device=device,
)
# Adjust indices for the global batch
all_edges.append(edges + offset)
all_vectors.append(vectors)
all_unit_shifts.append(unit_shifts)
num_edges.append(len(edges[0]))
# Calculate lattice parameters
lattice_per_system = torch.from_numpy(
cell_to_cellpar(cell_per_system.squeeze(0).cpu().numpy())
)
# Create features dictionaries
node_feats = {
"positions": positions_per_system,
"atomic_numbers": atomic_numbers_per_system.to(torch.long),
"atomic_numbers_embedding": atomic_numbers_embedding_per_system,
"atom_identity": torch.arange(
len(positions_per_system), device=positions_per_system.device
).to(torch.long),
}
edge_feats = {
"vectors": vectors,
"unit_shifts": unit_shifts,
}
graph_feats = {
"cell": cell_per_system,
"pbc": pbc_per_system,
"lattice": lattice_per_system.to(device=positions_per_system.device),
}
# Add batch dimension to non-scalar graph features
graph_feats = {
k: v.unsqueeze(0) if v.numel() > 1 else v for k, v in graph_feats.items()
}
node_feats_list.append(node_feats)
edge_feats_list.append(edge_feats)
graph_feats_list.append(graph_feats)
# Update offset for next system
offset += len(positions_per_system)
# Concatenate all the edge data
edge_index = torch.cat(all_edges, dim=1)
unit_shifts = torch.cat(all_unit_shifts, dim=0)
batch_num_edges = torch.tensor(num_edges, dtype=torch.int64, device=device)
senders, receivers = edge_index[0], edge_index[1]
# Create and return AtomGraphs object
return AtomGraphs(
senders=senders,
receivers=receivers,
n_node=n_node,
n_edge=batch_num_edges,
node_features=_map_concat(node_feats_list),
edge_features=_map_concat(edge_feats_list),
system_features=_map_concat(graph_feats_list),
node_targets={},  # No targets needed; the graphs are used for inference only
edge_targets={},
system_targets={},
fix_atoms=None, # No fixed atoms in SimState
tags=None, # No tags in SimState
radius=system_config.radius,
max_num_neighbors=torch.tensor([max_num_neighbors] * len(n_node)),
system_id=None,
).to(device=device, dtype=output_dtype)
class OrbModel(torch.nn.Module, ModelInterface):
"""Computes atomistic energies, forces and stresses using an ORB model.
This class wraps an ORB model to compute energies, forces, and stresses for
atomistic systems. It handles model initialization, configuration, and
provides a forward pass that accepts a SimState object and returns model
predictions.
Attributes:
model (Union[GraphRegressor, ConservativeForcefieldRegressor]): The ORB model
system_config (SystemConfig): Configuration for the atomic system
conservative (bool): Whether to use conservative forces/stresses calculation
implemented_properties (list): Properties the model can compute
_dtype (torch.dtype): Data type used for computation
_device (torch.device): Device where computation is performed
_edge_method (EdgeCreationMethod): Method for creating edges in the graph
_max_num_neighbors (int): Maximum number of neighbors for each atom
_half_supercell (bool): Whether to use half supercell optimization
_memory_scales_with (str): What the memory usage scales with
Examples:
>>> model = OrbModel(model=loaded_orb_model, compute_stress=True)
>>> results = model(state)
"""
def __init__(
self,
model: GraphRegressor | ConservativeForcefieldRegressor | str | Path,
*, # force remaining arguments to be keyword-only
conservative: bool | None = None,
compute_stress: bool = True,
compute_forces: bool = True,
system_config: SystemConfig | None = None,
max_num_neighbors: int | None = None,
edge_method: EdgeCreationMethod | None = None,
half_supercell: bool | None = None,
device: torch.device | str | None = None,
dtype: torch.dtype = torch.float32,
) -> None:
"""Initialize the OrbModel with specified configuration.
Loads an ORB model from either a model object or a model path.
Sets up the model parameters for subsequent use in energy and force calculations.
Args:
model (Union[GraphRegressor, ConservativeForcefieldRegressor, str, Path]):
Either a model object or a path to a saved model
conservative (bool | None): Whether to use conservative forces/stresses.
If None, determined based on model type.
compute_stress (bool): Whether to compute stress tensor
compute_forces (bool): Whether to compute forces
system_config (SystemConfig | None): Configuration for the atomic system.
If None, defaults to SystemConfig(radius=6.0, max_num_neighbors=20).
max_num_neighbors (int | None): Maximum number of neighbors for each atom
edge_method (EdgeCreationMethod | None): Method for creating edges
half_supercell (bool | None): Whether to use half supercell optimization.
If None, determined based on system size.
device (torch.device | str | None): Device to run the model on
dtype (torch.dtype): Data type for computation. Defaults to torch.float32.
Raises:
ValueError: If conservative mode is requested but model doesn't support it
ImportError: If orb_models is not installed
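Example:
Constructing from a saved checkpoint (the path below is a hypothetical
placeholder; conservative=True requires the checkpoint to hold a
ConservativeForcefieldRegressor):
>>> model = OrbModel(model="orb_checkpoint.pt", conservative=True)  # hypothetical path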
"""
super().__init__()
self._device = device or torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
if isinstance(self._device, str):
self._device = torch.device(self._device)
self._dtype = dtype
self._compute_stress = compute_stress
self._compute_forces = compute_forces
# Load the model first if a checkpoint path was provided, so that
# model.system_config is available below
if isinstance(model, str | Path):
    model = torch.load(model, map_location=self._device)
# Set up system configuration
self.system_config = system_config or model.system_config
self._max_num_neighbors = max_num_neighbors
self._edge_method = edge_method
self._half_supercell = half_supercell
self.conservative = conservative
self.model = model.to(self._device)
self.model = self.model.eval()
if self.dtype is not None:
self.model = self.model.to(dtype=self.dtype)
# Determine if the model is conservative
model_is_conservative = hasattr(self.model, "grad_forces_name")
if self.conservative is None:
self.conservative = model_is_conservative
if self.conservative and not model_is_conservative:
raise ValueError(
"Conservative mode requested, but model is not a "
"ConservativeForcefieldRegressor."
)
# Set up implemented properties
self.implemented_properties = self.model.properties
# Add forces and stress to implemented properties if conservative model
if self.conservative:
self.implemented_properties.extend(["forces", "stress"])
def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
"""Perform forward pass to compute energies, forces, and other properties.
Takes a simulation state and computes the properties implemented by the model,
such as energy, forces, and stresses.
Args:
state (SimState | StateDict): State object containing positions, cells,
atomic numbers, and other system information. If a dictionary is provided,
it will be converted to a SimState.
Returns:
dict: Dictionary of model predictions, which may include:
- energy (torch.Tensor): Energy with shape [batch_size]
- forces (torch.Tensor): Forces with shape [n_atoms, 3]
- stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3],
if compute_stress is True
Notes:
The state is automatically transferred to the model's device if needed.
All output tensors are detached from the computation graph.
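Example:
A sketch of inspecting the outputs (assumes model is an OrbModel instance and
state is a batched SimState):
>>> results = model(state)
>>> results["energy"]  # per-system energies
>>> results["forces"].shape  # (n_atoms, 3)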
"""
if isinstance(state, dict):
state = SimState(**state, masses=torch.ones_like(state["positions"]))
if state.device != self._device:
state = state.to(self._device)
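# If not set explicitly, enable the half-supercell optimization only for large
# cells (cell volume above 1000 Å^3), where it can improve graph construction
# performance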
half_supercell = (
torch.max(torch.det(state.cell)) > 1000
if self._half_supercell is None
else self._half_supercell
)
# Convert state to atom graphs
batch = state_to_atom_graphs(
state,
system_config=self.system_config,
max_num_neighbors=self._max_num_neighbors,
edge_method=self._edge_method,
half_supercell=half_supercell,
device=self.device,
)
# Run forward pass
predictions = self.model.predict(batch)
results = {}
model_has_direct_heads = (
"forces" in self.model.heads and "stress" in self.model.heads
)
for prop in self.implemented_properties:
# The model has no direct heads for forces/stress, so we skip these props.
if not model_has_direct_heads and prop == "forces":
continue
if not model_has_direct_heads and prop == "stress":
continue
_property = "energy" if prop == "free_energy" else prop
results[prop] = predictions[_property].squeeze()
if self.conservative:
results["forces"] = results[self.model.grad_forces_name]
results["stress"] = results[self.model.grad_stress_name]
if "stress" in results and results["stress"].shape[-1] == 6:
results["stress"] = voigt_6_to_full_3x3_stress(results["stress"])
return results