ModelInterface
- class torch_sim.models.interface.ModelInterface(model=None, device=None, dtype=torch.float64, **kwargs)
Bases: ABC
Abstract base class for all simulation models in torchsim.
This interface provides a common structure for all energy and force models, ensuring they implement the required methods and properties. It defines how models should process atomic positions and system information to compute energies, forces, and stresses.
- Variables:
device (device) – Device where the model runs computations.
dtype (dtype) – Data type used for tensor calculations.
compute_stress (bool) – Whether the model calculates stress tensors.
compute_forces (bool) – Whether the model calculates atomic forces.
memory_scales_with (MemoryScaling) – The quantity that the model's memory usage scales with. “n_atoms” uses only the atom count and suits models with a fixed number of neighbors. “n_atoms_x_density” uses the atom count multiplied by the number density and suits models with radial cutoffs. Defaults to “n_atoms_x_density”.
- Parameters:
- Return type:
Self
Examples
```py
import torch

# Using a model that implements ModelInterface
# (assumes LennardJonesModel has already been imported from torch_sim)
model = LennardJonesModel(device=torch.device("cuda"))

# Forward pass with a simulation state (sim_state is an existing SimState)
output = model(sim_state)

# Access computed properties
energy = output["energy"]  # Shape: [n_batches]
forces = output["forces"]  # Shape: [n_atoms, 3]
stress = output["stress"]  # Shape: [n_batches, 3, 3]; present only if compute_stress=True
```
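The contract above can be mimicked without torch_sim to show how the abstract base class enforces it. This is a stdlib-only toy under stated assumptions: `ToyModelInterface` and `ZeroModel` are hypothetical names, plain lists and strings stand in for tensors, devices, and dtypes, and the real `ModelInterface` may store its attributes differently.

```py
from abc import ABC, abstractmethod
from typing import Literal


class ToyModelInterface(ABC):
    """Hypothetical stand-in for ModelInterface; lists replace torch tensors."""

    def __init__(self, device: str = "cpu", dtype: str = "float64") -> None:
        self.device = device
        self.dtype = dtype
        # Each concrete model decides which properties it computes.
        self.compute_forces = True
        self.compute_stress = False

    @property
    def memory_scales_with(self) -> Literal["n_atoms_x_density", "n_atoms"]:
        # Mirrors the documented default for radial-cutoff models.
        return "n_atoms_x_density"

    @abstractmethod
    def forward(self, state: dict, **kwargs) -> dict:
        """Return a dict with at least "energy" and "forces"."""

    def __call__(self, state: dict, **kwargs) -> dict:
        # Assumed: calling the model dispatches to forward, as the
        # class example's model(sim_state) suggests.
        return self.forward(state, **kwargs)


class ZeroModel(ToyModelInterface):
    """Trivial model: zero energy per system, zero force on every atom."""

    def forward(self, state: dict, **kwargs) -> dict:
        n_systems = max(state["batch"]) + 1
        return {
            "energy": [0.0] * n_systems,
            "forces": [[0.0, 0.0, 0.0] for _ in state["positions"]],
        }
```

Instantiating `ToyModelInterface` directly raises `TypeError` because `forward` is abstract, so only subclasses that implement it can be constructed.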
- property memory_scales_with: Literal['n_atoms_x_density', 'n_atoms']
The metric that the model's memory usage scales with.
Models with radial neighbor cutoffs scale with “n_atoms_x_density”, while models with a fixed number of neighbors scale with “n_atoms”. Default is “n_atoms_x_density” because most models are radial cutoff based.
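To make the two metrics concrete, here is a sketch of how a batching heuristic might score a single periodic system. The helper `memory_metric` and its exact formula (atom count times number density, with density taken as atoms per cell volume) are assumptions for illustration; torch_sim's own estimator may normalize or combine these quantities differently.

```py
def cell_volume(cell: list[list[float]]) -> float:
    """Absolute determinant of a 3x3 cell matrix (rows are lattice vectors)."""
    (a, b, c), (d, e, f), (g, h, i) = cell
    det = a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
    return abs(det)


def memory_metric(n_atoms: int, cell: list[list[float]], scales_with: str) -> float:
    """Hypothetical memory score for one system under either scaling mode."""
    if scales_with == "n_atoms":
        # Fixed neighbor count: cost grows linearly with the number of atoms.
        return float(n_atoms)
    if scales_with == "n_atoms_x_density":
        # Radial cutoff: neighbors per atom grow with the number density.
        density = n_atoms / cell_volume(cell)
        return n_atoms * density
    raise ValueError(f"unknown scaling mode: {scales_with!r}")


dense = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
dilute = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]
```

Eight atoms in `dense` and eight atoms in `dilute` score identically under “n_atoms” but differ by a factor of eight under “n_atoms_x_density”, which is why a radial-cutoff model can fit more dilute systems into one batch.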
- abstract forward(state, **kwargs)
Calculate energies, forces, and stresses for an atomistic system.
This is the main computational method that all model implementations must provide. It takes atomic positions and system information as input and returns a dictionary containing computed physical properties.
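A forward implementation typically receives several systems packed into flat per-atom arrays, with “batch” mapping each atom to its system. The sketch below checks the shapes documented in the Parameters field; it uses nested lists in place of tensors, and `check_state_dict` is a hypothetical helper rather than part of torch_sim. Real models may require additional keys.

```py
def check_state_dict(state: dict) -> int:
    """Check the documented shapes; return the number of systems (n_batches)."""
    n_atoms = len(state["positions"])
    if any(len(p) != 3 for p in state["positions"]):
        raise ValueError("positions must have shape [n_atoms, 3]")
    if len(state["batch"]) != n_atoms:
        raise ValueError("batch must have shape [n_atoms]")
    # Systems are numbered 0..n_batches-1, so the largest index fixes the count.
    n_batches = max(state["batch"]) + 1
    if len(state["cell"]) != n_batches or any(
        len(c) != 3 or any(len(row) != 3 for row in c) for c in state["cell"]
    ):
        raise ValueError("cell must have shape [n_batches, 3, 3]")
    # atomic_numbers is optional, but must match the atom count when present.
    if "atomic_numbers" in state and len(state["atomic_numbers"]) != n_atoms:
        raise ValueError("atomic_numbers must have shape [n_atoms]")
    return n_batches


# Two systems packed together: atoms 0-1 form system 0, atom 2 forms system 1.
state = {
    "positions": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
    "cell": [
        [[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0]],
        [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]],
    ],
    "batch": [0, 0, 1],
    "atomic_numbers": [1, 1, 2],
}
```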
- Parameters:
state (SimState | StateDict) – Simulation state or state dictionary. The required keys depend on the model; typically they are:
- “positions”: Atomic positions with shape [n_atoms, 3]
- “cell”: Unit cell vectors with shape [n_batches, 3, 3]
- “batch”: Batch indices for each atom with shape [n_atoms]
- “atomic_numbers”: Atomic numbers with shape [n_atoms] (optional)
**kwargs – Additional model-specific parameters.
- Returns:
- Dictionary containing computed properties:
- “energy”: Potential energy with shape [n_batches]
- “forces”: Atomic forces with shape [n_atoms, 3]
- “stress”: Stress tensor with shape [n_batches, 3, 3] (if compute_stress=True)
- May include additional model-specific outputs
- Return type:
Examples
```py
# Compute energies and forces with a model
output = model.forward(state)

energy = output["energy"]
forces = output["forces"]
stress = output.get("stress", None)
```
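To see how a concrete model can honor this contract, here is a toy pairwise harmonic-spring model written against plain lists. Everything in it is illustrative: `harmonic_forward`, the spring constant `k`, and the rest length `r0` are hypothetical, periodic images are ignored, and stress is not computed. It shows two bookkeeping rules from the Returns field: energies are accumulated per system via “batch”, and forces are per atom, equal to minus the energy gradient.

```py
import math


def harmonic_forward(state: dict, k: float = 1.0, r0: float = 1.0) -> dict:
    """Toy model: E = sum over same-system pairs of 0.5 * k * (r - r0)**2.

    Ignores the cell (no periodic images), assumes no two atoms in the same
    system coincide, and returns no stress.
    """
    pos, batch = state["positions"], state["batch"]
    n_atoms = len(pos)
    n_batches = max(batch) + 1
    energy = [0.0] * n_batches
    forces = [[0.0, 0.0, 0.0] for _ in range(n_atoms)]

    for i in range(n_atoms):
        for j in range(i + 1, n_atoms):
            if batch[i] != batch[j]:
                continue  # atoms in different systems never interact
            d = [pos[i][a] - pos[j][a] for a in range(3)]
            r = math.sqrt(sum(x * x for x in d))
            stretch = r - r0
            energy[batch[i]] += 0.5 * k * stretch**2
            # F_i = -dE/dr_i = -k * (r - r0) * d / r; F_j is equal and opposite.
            for a in range(3):
                f = -k * stretch * d[a] / r
                forces[i][a] += f
                forces[j][a] -= f

    return {"energy": energy, "forces": forces}
```

Because this toy never adds a “stress” entry, `output.get("stress", None)` on its result returns `None`, matching the optional-stress pattern in the example above.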