"""Core interfaces for all models in torchsim.
This module defines the abstract base class that all torchsim models must implement.
It establishes a common API for interacting with different force and energy models,
ensuring consistent behavior regardless of the underlying implementation. The module
also provides validation utilities to verify model conformance to the interface.
Example::

    # Creating a custom model that implements the interface
    class MyModel(ModelInterface):
        def __init__(self, device=None, dtype=torch.float64):
            self._device = device or torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )
            self._dtype = dtype
            self._compute_stress = True
            self._compute_forces = True

        def forward(self, state, **kwargs):
            # Implementation that returns energy, forces, and stress
            return {"energy": energy, "forces": forces, "stress": stress}

Notes:
Models must explicitly declare support for stress computation through the
compute_stress property, as some integrators require stress calculations.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Self
import torch
import torch_sim as ts
from torch_sim.state import SimState
from torch_sim.typing import MemoryScaling, StateDict
[docs]
class ModelInterface(ABC):
"""Abstract base class for all simulation models in torchsim.
This interface provides a common structure for all energy and force models,
ensuring they implement the required methods and properties. It defines how
models should process atomic positions and system information to compute energies,
forces, and stresses.
Attributes:
device (torch.device): Device where the model runs computations.
dtype (torch.dtype): Data type used for tensor calculations.
compute_stress (bool): Whether the model calculates stress tensors.
compute_forces (bool): Whether the model calculates atomic forces.
memory_scales_with (MemoryScaling): The metric
that the model scales with. "n_atoms" uses only atom count and is suitable
for models that have a fixed number of neighbors. "n_atoms_x_density" uses
atom count multiplied by number density and is better for models with
radial cutoffs. Defaults to "n_atoms_x_density".
Examples:
```py
# Using a model that implements ModelInterface
model = LennardJonesModel(device=torch.device("cuda"))
# Forward pass with a simulation state
output = model(sim_state)
# Access computed properties
energy = output["energy"] # Shape: [n_batches]
forces = output["forces"] # Shape: [n_atoms, 3]
stress = output["stress"] # Shape: [n_batches, 3, 3]
```
"""
@abstractmethod
def __init__(
self,
model: str | Path | torch.nn.Module | None = None,
device: torch.device | None = None,
dtype: torch.dtype = torch.float64,
**kwargs,
) -> Self:
"""Initialize a model implementation.
Implementations must set device, dtype and compute capability flags
to indicate what operations the model supports. Models may optionally
load parameters from a file or existing module.
Args:
model (str | Path | torch.nn.Module | None): Model specification, which
can be:
- Path to a model checkpoint or model file
- Pre-configured torch.nn.Module
- None for default initialization
Defaults to None.
device (torch.device | None): Device where the model will run. If None,
a default device will be selected. Defaults to None.
dtype (torch.dtype): Data type for model calculations. Defaults to
torch.float64.
**kwargs: Additional model-specific parameters.
Notes:
All implementing classes must set self._device, self._dtype,
self._compute_stress and self._compute_forces in their __init__ method.
"""
@property
def device(self) -> torch.device:
"""The device of the model."""
return self._device
@device.setter
def device(self, device: torch.device) -> None:
raise NotImplementedError(
"No device setter has been defined for this model"
" so the device cannot be changed after initialization."
)
@property
def dtype(self) -> torch.dtype:
"""The data type of the model."""
return self._dtype
@dtype.setter
def dtype(self, dtype: torch.dtype) -> None:
raise NotImplementedError(
"No dtype setter has been defined for this model"
" so the dtype cannot be changed after initialization."
)
@property
def compute_stress(self) -> bool:
"""Whether the model computes stresses."""
return self._compute_stress
@compute_stress.setter
def compute_stress(self, compute_stress: bool) -> None:
raise NotImplementedError(
"No compute_stress setter has been defined for this model"
" so compute_stress cannot be set after initialization."
)
@property
def compute_forces(self) -> bool:
"""Whether the model computes forces."""
return self._compute_forces
@compute_forces.setter
def compute_forces(self, compute_forces: bool) -> None:
raise NotImplementedError(
"No compute_forces setter has been defined for this model"
" so compute_forces cannot be set after initialization."
)
@property
def memory_scales_with(self) -> MemoryScaling:
"""The metric that the model scales with.
Models with radial neighbor cutoffs scale with "n_atoms_x_density",
while models with a fixed number of neighbors scale with "n_atoms".
Default is "n_atoms_x_density" because most models are radial cutoff based.
"""
return getattr(self, "_memory_scales_with", "n_atoms_x_density")
[docs]
@abstractmethod
def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tensor]:
"""Calculate energies, forces, and stresses for a atomistic system.
This is the main computational method that all model implementations must provide.
It takes atomic positions and system information as input and returns a dictionary
containing computed physical properties.
Args:
state (SimState | StateDict): Simulation state or state dictionary. The state
dictionary is dependent on the model but typically must contain the
following keys:
- "positions": Atomic positions with shape [n_atoms, 3]
- "cell": Unit cell vectors with shape [n_batches, 3, 3]
- "batch": Batch indices for each atom with shape [n_atoms]
- "atomic_numbers": Atomic numbers with shape [n_atoms] (optional)
**kwargs: Additional model-specific parameters.
Returns:
dict[str, torch.Tensor]: Dictionary containing computed properties:
- "energy": Potential energy with shape [n_batches]
- "forces": Atomic forces with shape [n_atoms, 3]
- "stress": Stress tensor with shape [n_batches, 3, 3] (if
compute_stress=True)
- May include additional model-specific outputs
Examples:
```py
# Compute energies and forces with a model
output = model.forward(state)
energy = output["energy"]
forces = output["forces"]
stress = output.get("stress", None)
```
"""
[docs]
def validate_model_outputs(
    model: ModelInterface,
    device: torch.device,
    dtype: torch.dtype,
) -> None:
    """Validate the outputs of a model implementation against the interface.

    Runs a series of tests to ensure a model implementation correctly follows
    the ModelInterface contract. The tests include creating sample systems,
    running forward passes, and verifying output shapes and consistency
    between batched and single-system evaluation.

    Args:
        model (ModelInterface): Model implementation to validate.
        device (torch.device): Device to run the validation tests on.
        dtype (torch.dtype): Data type to use for validation tensors.

    Raises:
        AssertionError: If the model doesn't conform to the required interface,
            including issues with output shapes, types, or behavior
            consistency.

    Example::

        # Create a new model implementation
        model = MyCustomModel(device=torch.device("cuda"))

        # Validate that it correctly implements the interface
        validate_model_outputs(
            model, device=torch.device("cuda"), dtype=torch.float64
        )

    Notes:
        This validator creates small test systems (silicon and iron) for
        validation. It tests both single and multi-batch processing
        capabilities.

        If ``compute_stress`` or ``compute_forces`` is off, the validator
        tries to switch it on through the property setter and does not restore
        the previous value. When the setter raises ``NotImplementedError`` the
        corresponding checks are skipped.
    """
    # Imported lazily so that ASE is only needed when validation is run.
    from ase.build import bulk

    assert model.dtype is not None
    assert model.device is not None
    assert model.compute_stress is not None
    assert model.compute_forces is not None

    # Try to enable each capability; a model whose setter is not implemented
    # and whose flag is off is validated without that property.
    try:
        if not model.compute_stress:
            model.compute_stress = True
        stress_computed = True
    except NotImplementedError:
        stress_computed = False

    try:
        if not model.compute_forces:
            model.compute_forces = True
        force_computed = True
    except NotImplementedError:
        force_computed = False

    si_atoms = bulk("Si", "diamond", a=5.43, cubic=True)
    fe_atoms = bulk("Fe", "fcc", a=5.26, cubic=True).repeat([3, 1, 1])

    sim_state = ts.io.atoms_to_state([si_atoms, fe_atoms], device, dtype)

    # Snapshot the inputs so that in-place mutation by the model is detected.
    og_positions = sim_state.positions.clone()
    og_cell = sim_state.cell.clone()
    og_batch = sim_state.batch.clone()
    og_atomic_numbers = sim_state.atomic_numbers.clone()

    model_output = model.forward(sim_state)

    # assert model did not mutate the input
    assert torch.allclose(og_positions, sim_state.positions), "positions were mutated"
    assert torch.allclose(og_cell, sim_state.cell), "cell was mutated"
    assert torch.allclose(og_batch, sim_state.batch), "batch was mutated"
    assert torch.allclose(
        og_atomic_numbers, sim_state.atomic_numbers
    ), "atomic_numbers were mutated"

    # assert model output has the correct keys
    assert "energy" in model_output, "model output is missing 'energy'"
    if force_computed:
        assert "forces" in model_output, "model output is missing 'forces'"
    if stress_computed:
        assert "stress" in model_output, "model output is missing 'stress'"

    # assert model output shapes are correct (2 systems, 20 atoms in total)
    assert model_output["energy"].shape == (2,)
    if force_computed:
        assert model_output["forces"].shape == (20, 3)
    if stress_computed:
        assert model_output["stress"].shape == (2, 3, 3)

    si_state = ts.io.atoms_to_state([si_atoms], device, dtype)
    fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype)

    # Evaluating a system on its own must agree with its slice of the batch.
    si_model_output = model.forward(si_state)
    assert torch.allclose(
        si_model_output["energy"], model_output["energy"][0], atol=10e-3
    ), "Si energy differs between single and batched evaluation"
    if force_computed:
        # Only compare forces when the model produces them; indexing
        # unconditionally would raise KeyError for energy-only models.
        assert torch.allclose(
            si_model_output["forces"],
            model_output["forces"][: si_state.n_atoms],
            atol=10e-3,
        ), "Si forces differ between single and batched evaluation"

    fe_model_output = model.forward(fe_state)
    assert torch.allclose(
        fe_model_output["energy"], model_output["energy"][1], atol=10e-2
    ), "Fe energy differs between single and batched evaluation"
    if force_computed:
        assert torch.allclose(
            fe_model_output["forces"],
            model_output["forces"][si_state.n_atoms :],
            atol=10e-2,
        ), "Fe forces differ between single and batched evaluation"

    # NOTE: per-system stress is not compared against the batched stress;
    # only its presence and shape are validated above.