FairChemModel

class torch_sim.models.fairchem.FairChemModel(model, neighbor_list_fn=None, *, config_yml=None, model_name=None, local_cache=None, trainer=None, cpu=False, seed=None, dtype=None, compute_stress=False, pbc=True, disable_amp=True)

Bases: Module, ModelInterface

Computes atomistic energies, forces and stresses using a FairChem model.

This class wraps a FairChem model to compute energies, forces, and stresses for atomistic systems. It handles model initialization and checkpoint loading, and provides a forward pass that accepts a SimState object and returns the model predictions.

The model can be initialized either from a configuration file or from a pretrained checkpoint, and it supports the model architectures and configurations available in FairChem.
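The energy/forces contract of such a wrapper can be illustrated with a toy model. The sketch below is not FairChem code: ToyPairModel is a hypothetical stand-in that evaluates a Lennard-Jones pair energy for a single system and obtains forces as the negative gradient of the energy with respect to positions. Whether a particular FairChem architecture computes forces this way or predicts them directly depends on the model, so treat this only as an illustration of the output dictionary.

```python
import torch


class ToyPairModel(torch.nn.Module):
    """Hypothetical stand-in for a wrapped ML potential (not FairChem)."""

    def __init__(self, epsilon: float = 1.0, sigma: float = 1.0) -> None:
        super().__init__()
        self.epsilon = epsilon
        self.sigma = sigma

    def forward(self, positions: torch.Tensor) -> dict[str, torch.Tensor]:
        # Track gradients w.r.t. positions so that forces = -dE/dr.
        positions = positions.detach().requires_grad_(True)
        n = positions.shape[0]
        # All unique atom pairs (i < j) of a single, non-periodic system.
        i, j = torch.triu_indices(n, n, offset=1)
        r = (positions[i] - positions[j]).norm(dim=-1)
        sr6 = (self.sigma / r) ** 6
        energy = (4 * self.epsilon * (sr6**2 - sr6)).sum()
        (grad,) = torch.autograd.grad(energy, positions)
        # Return detached tensors: a scalar energy and [n_atoms, 3] forces.
        return {"energy": energy.detach(), "forces": -grad}
```

At the Lennard-Jones minimum, r = 2^(1/6) sigma, the pair energy is -epsilon and the forces vanish, which makes a convenient sanity check.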

Variables:
  • neighbor_list_fn (Callable | None) – Function to compute neighbor lists

  • config (dict) – Complete model configuration dictionary

  • trainer – FairChem trainer object that contains the model

  • data_object (Batch) – Data object containing system information

  • implemented_properties (list) – Properties the model can compute

  • pbc (bool) – Whether periodic boundary conditions are used

  • _dtype (dtype) – Data type used for computation

  • _compute_stress (bool) – Whether to compute stress tensor

  • _compute_forces (bool) – Whether to compute forces

  • _device (device) – Device where computation is performed

  • _reshaped_props (dict) – Properties that need reshaping after computation
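To illustrate what a neighbor_list_fn computes and how pbc affects the result, here is a brute-force sketch. naive_neighbor_list is a hypothetical helper, not part of torch_sim; it assumes an orthorhombic cell described by its three edge lengths and a cutoff smaller than half the shortest edge, which the minimum-image convention requires.

```python
import torch


def naive_neighbor_list(
    positions: torch.Tensor,     # [N, 3] Cartesian coordinates
    cell_lengths: torch.Tensor,  # [3] edge lengths of an orthorhombic cell
    cutoff: float,
    pbc: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical O(N^2) neighbor list returning (i, j, distance)."""
    # disp[i, j] = positions[j] - positions[i], shape [N, N, 3].
    disp = positions[None, :, :] - positions[:, None, :]
    if pbc:
        # Minimum-image convention: shift each component by whole cell
        # lengths so it lies (roughly) within [-L/2, L/2].
        disp = disp - cell_lengths * torch.round(disp / cell_lengths)
    dist = disp.norm(dim=-1)
    n = positions.shape[0]
    # Keep pairs within the cutoff, excluding each atom paired with itself.
    self_pairs = torch.eye(n, device=positions.device).bool()
    mask = (dist < cutoff) & ~self_pairs
    i, j = torch.nonzero(mask, as_tuple=True)
    return i, j, dist[i, j]
```

Two atoms near opposite faces of a 10-unit cell are neighbors only when pbc is enabled, because the wrapped separation is much shorter than the direct one.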

Parameters:
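  • model (str | Path | None) – Path to a pretrained model checkpoint

  • neighbor_list_fn (Callable | None) – Function to compute neighbor lists

  • config_yml (str | None) – Path to a model configuration file

  • model_name (str | None)

  • local_cache (str | None)

  • trainer (str | None)

  • cpu (bool) – Whether to run the model on the CPU

  • seed (int | None) – Random seed

  • dtype (dtype | None) – Data type used for computation

  • compute_stress (bool) – Whether to compute the stress tensor

  • pbc (bool) – Whether to use periodic boundary conditions

  • disable_amp (bool) – Whether to disable automatic mixed precision (AMP)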

Examples

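>>> from torch_sim.models.fairchem import FairChemModel
>>> model = FairChemModel(model="path/to/checkpoint.pt", compute_stress=True)
>>> results = model(state)  # state: a SimState (or StateDict) for one or more systems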

load_checkpoint(checkpoint_path, checkpoint=None)

Load an existing trained model checkpoint.

Loads model parameters from a checkpoint file or dictionary, setting the model to inference mode.

Parameters:
  • checkpoint_path (str) – Path to the trained model checkpoint file

  • checkpoint (dict | None) – A pretrained checkpoint dictionary. If provided, this dictionary is used instead of loading from checkpoint_path.

Return type:

None

Notes

If loading fails, a message is printed but no exception is raised.

forward(state)

Perform a forward pass to compute energies, forces, and other properties.

Takes a simulation state and computes the properties implemented by the model, such as energy, forces, and stresses.

Parameters:

state (SimState | StateDict) – State object containing positions, cells, atomic numbers, and other system information. If a dictionary is provided, it will be converted to a SimState.

Returns:

Dictionary of model predictions, which may include:
  • energy (torch.Tensor): Energy with shape [batch_size]

  • forces (torch.Tensor): Forces with shape [n_atoms, 3]

  • stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3], present only if compute_stress is True

Return type:

dict
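The shapes above differ because energy and stress are per system while forces are per atom. A minimal sketch, assuming a hypothetical system_idx tensor that maps each atom to its system, shows how per-atom contributions can be summed into a [batch_size] energy and how the documented output shapes could be checked. Neither function is part of torch_sim.

```python
import torch


def sum_per_system(
    per_atom: torch.Tensor, system_idx: torch.Tensor, n_systems: int
) -> torch.Tensor:
    """Hypothetical helper: sum per-atom values into per-system totals."""
    out = torch.zeros(n_systems, dtype=per_atom.dtype, device=per_atom.device)
    # index_add_ accumulates per_atom[k] into out[system_idx[k]].
    return out.index_add_(0, system_idx, per_atom)


def check_output_shapes(results: dict, n_systems: int, n_atoms: int) -> None:
    """Hypothetical helper: assert the shapes documented for forward()."""
    assert results["energy"].shape == (n_systems,)
    assert results["forces"].shape == (n_atoms, 3)
    if "stress" in results:  # only returned when compute_stress=True
        assert results["stress"].shape == (n_systems, 3, 3)
```

For example, a batch of a 3-atom and a 2-atom system with system_idx = [0, 0, 0, 1, 1] yields an energy tensor of shape [2] and forces of shape [5, 3].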

Notes

The state is automatically transferred to the model’s device if needed. All output tensors are detached from the computation graph.
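Both notes correspond to standard PyTorch idioms, sketched below. run_inference is a hypothetical helper, not torch_sim code; it assumes the model has at least one parameter (used to find its device) and returns a dictionary of tensors.

```python
import torch


def run_inference(
    model: torch.nn.Module, positions: torch.Tensor
) -> dict[str, torch.Tensor]:
    """Hypothetical helper: device transfer on input, detach on output."""
    # Assumption: the model has parameters, all on a single device.
    device = next(model.parameters()).device
    # Move inputs to the model's device only when they live elsewhere.
    if positions.device != device:
        positions = positions.to(device)
    outputs = model(positions)
    # Detach so callers receive plain tensors without autograd history.
    return {key: value.detach() for key, value in outputs.items()}
```

Detaching the outputs keeps long simulations from accumulating autograd graphs across steps, at the cost of not being able to differentiate through the returned values.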