FairChemModel
- class torch_sim.models.fairchem.FairChemModel(model, neighbor_list_fn=None, *, config_yml=None, model_name=None, local_cache=None, trainer=None, cpu=False, seed=None, dtype=None, compute_stress=False, pbc=True, disable_amp=True)
Bases: Module, ModelInterface
Computes atomistic energies, forces and stresses using a FairChem model.
This class wraps a FairChem model to compute energies, forces, and stresses for atomistic systems. It handles model initialization and checkpoint loading, and provides a forward pass that accepts a SimState object and returns the model's predictions.
The model can be initialized either from a configuration file (config_yml) or from a pretrained checkpoint (model). It supports the model architectures and configurations available in FairChem.
- Variables:
neighbor_list_fn (Callable | None) – Function to compute neighbor lists
config (dict) – Complete model configuration dictionary
trainer – FairChem trainer object that contains the model
data_object (Batch) – Data object containing system information
implemented_properties (list) – Properties the model can compute
pbc (bool) – Whether periodic boundary conditions are used
_dtype (dtype) – Data type used for computation
_compute_stress (bool) – Whether to compute stress tensor
_compute_forces (bool) – Whether to compute forces
_device (device) – Device where computation is performed
_reshaped_props (dict) – Properties that need reshaping after computation
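_reshaped_props is described only as holding properties that need reshaping after computation. As a sketch, assuming a model returns each system's stress as a flat row of 9 components that must become a 3×3 matrix (the actual target shapes stored in _reshaped_props are an assumption here), the reshaping step looks like this. reshape_stress is a hypothetical helper, and plain lists stand in for tensors to keep the example dependency-free.

```python
def reshape_stress(flat_stress: list[list[float]]) -> list[list[list[float]]]:
    """Turn [batch_size, 9] rows into [batch_size, 3, 3] matrices in row-major order.

    Hypothetical helper: mirrors what reshape(-1, 3, 3) does on a tensor.
    """
    reshaped = []
    for row in flat_stress:
        if len(row) != 9:
            raise ValueError(f"expected 9 stress components, got {len(row)}")
        # Slice the flat row into three consecutive rows of three components.
        reshaped.append([row[i : i + 3] for i in range(0, 9, 3)])
    return reshaped


flat = [[1.0, 0.1, 0.2, 0.1, 2.0, 0.3, 0.2, 0.3, 3.0]]
print(reshape_stress(flat))
# [[[1.0, 0.1, 0.2], [0.1, 2.0, 0.3], [0.2, 0.3, 3.0]]]
```

With a real torch.Tensor, stress.reshape(-1, 3, 3) performs the same row-major regrouping in one call.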
- Parameters:
Examples
>>> model = FairChemModel(model="path/to/checkpoint.pt", compute_stress=True)
>>> results = model(state)
- load_checkpoint(checkpoint_path, checkpoint=None)
Load an existing trained model checkpoint.
Loads model parameters from a checkpoint file or dictionary, setting the model to inference mode.
- Parameters:
- Return type:
None
Notes
If loading fails, a message is printed but no exception is raised.
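Because load_checkpoint reports a failure only through a printed message, a caller that needs a hard failure must detect it another way. One option, sketched below with a hypothetical stand-in loader (lenient_load) rather than FairChem itself, is to capture standard output with contextlib.redirect_stdout and raise if anything was printed. This assumes the real method prints nothing when loading succeeds and uses print rather than logging; both should be checked against the installed version.

```python
import contextlib
import io


def lenient_load(path: str) -> None:
    """Hypothetical stand-in that prints on failure instead of raising."""
    try:
        with open(path, "rb") as fh:
            fh.read(1)
    except OSError as exc:
        print(f"Unable to load checkpoint: {exc}")


def raise_if_printed(fn, *args, **kwargs):
    """Call fn, capturing stdout; raise RuntimeError if it printed anything."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        result = fn(*args, **kwargs)
    message = buffer.getvalue().strip()
    if message:
        raise RuntimeError(message)
    return result


try:
    raise_if_printed(lenient_load, "does/not/exist.pt")
except RuntimeError as err:
    print("load failed:", err)
```

Applied to the real class, the call would read raise_if_printed(model.load_checkpoint, checkpoint_path). If the message is emitted through logging instead of print, capturing stdout will not see it.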
- forward(state)
Perform a forward pass to compute energies, forces, and other properties.
Takes a simulation state and computes the properties implemented by the model, such as energy, forces, and stresses.
- Parameters:
state (SimState | StateDict) – State object containing positions, cells, atomic numbers, and other system information. If a dictionary is provided, it will be converted to a SimState.
- Returns:
Dictionary of model predictions, which may include:
- energy (torch.Tensor): Energy with shape [batch_size]
- forces (torch.Tensor): Forces with shape [n_atoms, 3]
- stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3], if compute_stress is True
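These shapes imply that forces are concatenated over every atom in the batch, while energy and stress carry one entry per system. As a sketch, assuming each atom's system is given by an integer index list (in torch-sim the state carries a per-atom system index; the name system_idx below is only illustrative), per-system forces can be regrouped and a per-system pressure read off the stress tensor as P = -tr(σ)/3. That sign convention (positive stress is tensile) is an assumption to check against your model's conventions and units; plain lists stand in for tensors.

```python
from collections import defaultdict


def forces_by_system(forces, system_idx):
    """Group [n_atoms, 3] forces into one list per system via a per-atom system index."""
    if len(forces) != len(system_idx):
        raise ValueError("forces and system_idx must have one entry per atom")
    grouped = defaultdict(list)
    for force, sys_id in zip(forces, system_idx):
        grouped[sys_id].append(force)
    return dict(grouped)


def pressure_from_stress(stress):
    """Pressure of one system from its 3x3 stress, assuming P = -trace(stress) / 3."""
    return -(stress[0][0] + stress[1][1] + stress[2][2]) / 3.0


# Toy batch: two systems, three atoms (atoms 0-1 in system 0, atom 2 in system 1).
forces = [[0.1, 0.0, 0.0], [-0.1, 0.0, 0.0], [0.0, 0.2, 0.0]]
system_idx = [0, 0, 1]
stress = [
    [[-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0]],
    [[0.5, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 0.5]],
]

print(forces_by_system(forces, system_idx))
print([pressure_from_stress(s) for s in stress])
# [1.0, -0.5]
```

With tensors, the same grouping is a boolean mask such as forces[system_idx == i], and the batched pressure is -stress.diagonal(dim1=-2, dim2=-1).sum(-1) / 3.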
- Return type:
dict[str, torch.Tensor]
Notes
The state is automatically transferred to the model’s device if needed. All output tensors are detached from the computation graph.