torch_sim.models.interface

Core interfaces for all models in torchsim.

This module defines the abstract base class that all torchsim models must implement. It establishes a common API for interacting with different force and energy models, ensuring consistent behavior regardless of the underlying implementation. The module also provides validation utilities to verify model conformance to the interface.

Example:

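import torch

from torch_sim.models.interface import ModelInterface

# Creating a custom model that implements the interface
class MyModel(ModelInterface):
    def __init__(self, device=None, dtype=torch.float64):
        super().__init__()
        self._device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self._dtype = dtype
        self._compute_stress = True
        self._compute_forces = True

    def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs):
        # batch holds each atom's system index, so the system count is max + 1
        n_systems = int(batch.max()) + 1
        opts = {"device": positions.device, "dtype": positions.dtype}
        # Zero-valued placeholders: a real model computes these from the inputs
        energy = torch.zeros(n_systems, **opts)  # one energy per system
        forces = torch.zeros_like(positions)  # one 3-vector per atom
        stress = torch.zeros(n_systems, 3, 3, **opts)  # 3x3 tensor per system
        return {"energy": energy, "forces": forces, "stress": stress}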

Notes

Models must explicitly declare whether they support stress computation through the compute_stress property, because some integrators (for example, constant-pressure integrators that update the simulation cell) require stress.
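One way an integrator might act on this flag is to check it up front and fail fast instead of failing mid-simulation. The sketch below is hypothetical: `ToyModel` and `require_stress` are illustrative names, not part of torchsim, and the constant-pressure use case is an assumed example.

```python
from dataclasses import dataclass


@dataclass
class ToyModel:
    """Hypothetical stand-in that only exposes the capability flag."""

    compute_stress: bool


def require_stress(model) -> None:
    """Hypothetical guard: reject models that do not declare stress support."""
    if not getattr(model, "compute_stress", False):
        raise ValueError(
            f"{type(model).__name__} does not declare stress support "
            "(compute_stress is False); a constant-pressure integrator needs it."
        )


require_stress(ToyModel(compute_stress=True))  # passes silently

try:
    require_stress(ToyModel(compute_stress=False))
except ValueError as err:
    print(err)
```

Checking the declared flag, rather than probing the model's outputs, keeps the check cheap and makes the failure message point at the missing capability.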

Functions

validate_model_outputs

Validate the outputs of a model implementation against the interface requirements.
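The exact checks `validate_model_outputs` performs are not spelled out here. As a rough illustration of what validating outputs against the interface can mean, the hypothetical `check_output_shapes` below compares output shapes with assumed conventions: energy per system, forces per atom, and a 3x3 stress per system. It works on plain shape tuples so it runs without PyTorch; with real tensors you would pass `{k: tuple(v.shape) for k, v in outputs.items()}`. It is a sketch, not the library function.

```python
def check_output_shapes(shapes, n_atoms, n_systems, compute_forces, compute_stress):
    """Hypothetical checker. `shapes` maps output key -> shape tuple.

    Assumed conventions: energy is one value per system, forces one
    (x, y, z) vector per atom, stress one 3x3 tensor per system.
    Returns a list of problems; an empty list means the shapes conform.
    """
    expected = {"energy": (n_systems,)}
    if compute_forces:
        expected["forces"] = (n_atoms, 3)
    if compute_stress:
        expected["stress"] = (n_systems, 3, 3)

    problems = []
    for key, want in expected.items():
        got = shapes.get(key)
        if got is None:
            problems.append(f"missing '{key}'")
        elif tuple(got) != want:
            problems.append(f"'{key}' has shape {tuple(got)}, expected {want}")
    return problems


# Two systems with five atoms in total
ok = {"energy": (2,), "forces": (5, 3), "stress": (2, 3, 3)}
bad = {"energy": (5,), "forces": (5, 3)}
print(check_output_shapes(ok, 5, 2, compute_forces=True, compute_stress=True))
# []
print(check_output_shapes(bad, 5, 2, compute_forces=True, compute_stress=True))
# ["'energy' has shape (5,), expected (2,)", "missing 'stress'"]
```

Collecting every problem before reporting, rather than stopping at the first mismatch, makes it easier to fix a new model implementation in one pass.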

Classes

ModelInterface

Abstract base class for all simulation models in torchsim.
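To show how an abstract base class enforces the interface, here is a framework-free sketch using Python's `abc` module. `SketchModelInterface`, `Incomplete`, and `ConstantEnergy` are hypothetical names, and the sketch is simplified (plain lists instead of tensors); it mirrors the pattern of read-only capability properties plus an abstract `forward`, not the actual torchsim class.

```python
from abc import ABC, abstractmethod


class SketchModelInterface(ABC):
    """Hypothetical, simplified stand-in for an abstract model interface."""

    # Defaults; subclasses override these in __init__.
    _compute_stress: bool = False
    _compute_forces: bool = False

    @property
    def compute_stress(self) -> bool:
        return self._compute_stress

    @property
    def compute_forces(self) -> bool:
        return self._compute_forces

    @abstractmethod
    def forward(self, positions, cell, batch, **kwargs) -> dict:
        """Return a dict of outputs, at least an "energy" entry."""


class Incomplete(SketchModelInterface):
    pass  # does not implement forward, so it cannot be instantiated


class ConstantEnergy(SketchModelInterface):
    def __init__(self):
        self._compute_forces = True

    def forward(self, positions, cell, batch, **kwargs):
        n_systems = max(batch) + 1  # batch holds each atom's system index
        return {
            "energy": [0.0] * n_systems,
            "forces": [[0.0, 0.0, 0.0] for _ in positions],
        }


try:
    Incomplete()
except TypeError as err:
    print("rejected:", err)

model = ConstantEnergy()
print(model.compute_forces, model.compute_stress)  # True False
```

Because `forward` is abstract, a missing implementation surfaces as a `TypeError` at construction time rather than as an error deep inside a simulation loop.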