Source code for torch_sim.models.soft_sphere

"""Soft sphere model for computing energies, forces and stresses.

This module provides implementations of soft sphere potentials for molecular dynamics
simulations. Soft sphere potentials are repulsive interatomic potentials that model
the core repulsion between atoms, avoiding the infinite repulsion of hard sphere models
while maintaining computational efficiency.

The soft sphere potential has the form:
    V(r) = epsilon * (sigma/r)^alpha

Where:

* r is the distance between particles
* sigma is the effective diameter of the particles
* epsilon controls the energy scale
* alpha determines the steepness of the repulsion (typically alpha >= 2)

Soft sphere models are particularly useful for:

* Granular matter simulations
* Modeling excluded volume effects
* Initial equilibration of dense systems
* Coarse-grained molecular dynamics

Example::

    # Create a soft sphere model with default parameters
    model = SoftSphereModel()

    # Calculate properties for a simulation state
    results = model(sim_state)
    energy = results["energy"]
    forces = results["forces"]

    # For multiple species with different interaction parameters
    multi_model = SoftSphereMultiModel(
        species=particle_types,
        sigma_matrix=size_matrix,
        epsilon_matrix=strength_matrix,
    )
    results = multi_model(sim_state)

"""

import torch

from torch_sim.models.interface import ModelInterface
from torch_sim.neighbors import vesin_nl_ts
from torch_sim.state import SimState
from torch_sim.transforms import get_pair_displacements
from torch_sim.typing import StateDict
from torch_sim.unbatched.models.soft_sphere import (
    soft_sphere_pair,
    soft_sphere_pair_force,
)


# Default parameter values defined at module level
DEFAULT_SIGMA = torch.tensor(1.0)
DEFAULT_EPSILON = torch.tensor(1.0)
DEFAULT_ALPHA = torch.tensor(2.0)


class SoftSphereModel(torch.nn.Module, ModelInterface):
    """Calculator for soft sphere potential energies and forces.

    Implements a model for computing properties based on the soft sphere potential,
    which describes purely repulsive interactions between particles. This potential
    is useful for modeling systems where particles should not overlap but don't have
    attractive interactions, such as granular materials and some colloidal systems.

    The potential energy between particles i and j is:
        V_ij(r) = epsilon * (sigma/r)^alpha

    Attributes:
        sigma (torch.Tensor): Effective particle diameter in distance units.
        epsilon (torch.Tensor): Energy scale parameter in energy units.
        alpha (torch.Tensor): Exponent controlling repulsion steepness (typically ≥ 2).
        cutoff (torch.Tensor): Cutoff distance for interactions.
        use_neighbor_list (bool): Whether to use neighbor list optimization.
        _device (torch.device): Computation device (CPU/GPU).
        _dtype (torch.dtype): Data type for tensor calculations.
        _compute_forces (bool): Whether to compute forces.
        _compute_stress (bool): Whether to compute stress tensor.
        per_atom_energies (bool): Whether to compute per-atom energy decomposition.
        per_atom_stresses (bool): Whether to compute per-atom stress decomposition.

    Examples:
        ```py
        # Basic usage with default parameters
        model = SoftSphereModel()
        results = model(sim_state)

        # Custom parameters for colloidal system
        colloid_model = SoftSphereModel(
            sigma=2.0,  # particle diameter in nm
            epsilon=10.0,  # energy scale in kJ/mol
            alpha=12.0,  # steep repulsion for hard colloids
            compute_stress=True,
        )

        # Get forces for a system with periodic boundary conditions
        results = colloid_model(
            SimState(
                positions=positions,
                cell=box_vectors,
                pbc=torch.tensor([True, True, True]),
            )
        )
        forces = results["forces"]  # shape: [n_particles, 3]
        ```
    """

    def __init__(
        self,
        sigma: float = 1.0,
        epsilon: float = 1.0,
        alpha: float = 2.0,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
        *,  # Force keyword-only arguments
        compute_forces: bool = True,
        compute_stress: bool = False,
        per_atom_energies: bool = False,
        per_atom_stresses: bool = False,
        use_neighbor_list: bool = True,
        cutoff: float | None = None,
    ) -> None:
        """Initialize the soft sphere model.

        Creates a soft sphere model with specified parameters for particle
        interactions and computation options.

        Args:
            sigma (float): Effective particle diameter. Determines the distance
                scale of the interaction. Defaults to 1.0.
            epsilon (float): Energy scale parameter. Controls the strength of the
                repulsion. Defaults to 1.0.
            alpha (float): Exponent controlling repulsion steepness. Higher values
                create steeper, more hard-sphere-like repulsion. Defaults to 2.0.
            device (torch.device | None): Device for computations. If None, uses
                CPU. Defaults to None.
            dtype (torch.dtype): Data type for calculations. Defaults to
                torch.float32.
            compute_forces (bool): Whether to compute forces. Defaults to True.
            compute_stress (bool): Whether to compute stress tensor. Defaults to
                False.
            per_atom_energies (bool): Whether to compute per-atom energy
                decomposition. Defaults to False.
            per_atom_stresses (bool): Whether to compute per-atom stress
                decomposition. Defaults to False.
            use_neighbor_list (bool): Whether to use a neighbor list for
                optimization. Significantly faster for large systems. Defaults to
                True.
            cutoff (float | None): Cutoff distance for interactions. If None, uses
                the value of sigma. Defaults to None.

        Examples:
            ```py
            # Default model
            model = SoftSphereModel()

            # WCA-like repulsive potential (derived from Lennard-Jones)
            wca_model = SoftSphereModel(
                sigma=1.0,
                epsilon=1.0,
                alpha=12.0,  # Steep repulsion similar to r^-12 term in LJ
                cutoff=2 ** (1 / 6),  # WCA cutoff at minimum of LJ potential
            )
            ```
        """
        super().__init__()
        self._device = device or torch.device("cpu")
        self._dtype = dtype
        self._compute_forces = compute_forces
        self._compute_stress = compute_stress
        self.per_atom_energies = per_atom_energies
        self.per_atom_stresses = per_atom_stresses
        self.use_neighbor_list = use_neighbor_list

        # Convert interaction parameters to tensors with proper dtype/device
        self.sigma = torch.tensor(sigma, dtype=dtype, device=self.device)
        self.cutoff = torch.tensor(cutoff or sigma, dtype=dtype, device=self.device)
        self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self.device)
        self.alpha = torch.tensor(alpha, dtype=dtype, device=self.device)

    def unbatched_forward(
        self,
        state: SimState,
    ) -> dict[str, torch.Tensor]:
        """Compute energies and forces for a single unbatched system.

        Internal implementation that processes a single, non-batched simulation
        state. This method handles the core computations for pair interactions,
        including neighbor list construction, distance calculations, and property
        computation.

        Args:
            state (SimState): Single, non-batched simulation state containing atomic
                positions, cell vectors, and other system information.

        Returns:
            dict[str, torch.Tensor]: Dictionary of computed properties:
                - "energy": Total potential energy (scalar)
                - "forces": Atomic forces with shape [n_atoms, 3]
                  (if compute_forces=True)
                - "stress": Stress tensor with shape [3, 3]
                  (if compute_stress=True)
                - "energies": Per-atom energies with shape [n_atoms]
                  (if per_atom_energies=True)
                - "stresses": Per-atom stresses with shape [n_atoms, 3, 3]
                  (if per_atom_stresses=True)

        Notes:
            This method can work with both neighbor list and full pairwise
            calculations. The soft sphere potential is purely repulsive, and forces
            are truncated at the cutoff distance.
        """
        if isinstance(state, dict):
            state = SimState(**state, masses=torch.ones_like(state["positions"]))

        positions = state.positions
        cell = state.row_vector_cell
        cell = cell.squeeze()
        pbc = state.pbc

        if self.use_neighbor_list:
            # Get neighbor list using vesin_nl_ts
            mapping, shifts = vesin_nl_ts(
                positions=positions,
                cell=cell,
                pbc=pbc,
                cutoff=self.cutoff,
                sort_id=False,
            )
            # Get displacements between neighbor pairs
            dr_vec, distances = get_pair_displacements(
                positions=positions,
                cell=cell,
                pbc=pbc,
                pairs=mapping,
                shifts=shifts,
            )
        else:
            # Direct N^2 computation of all pairs
            dr_vec, distances = get_pair_displacements(
                positions=positions,
                cell=cell,
                pbc=pbc,
            )
            # Remove self-interactions and apply cutoff
            mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
            distances = distances.masked_fill(mask, float("inf"))
            mask = distances < self.cutoff

            # Get valid pairs and their displacements
            i, j = torch.where(mask)
            mapping = torch.stack([j, i])
            dr_vec = dr_vec[mask]
            distances = distances[mask]

        # Calculate pair energies using soft sphere potential
        pair_energies = soft_sphere_pair(
            distances, sigma=self.sigma, epsilon=self.epsilon, alpha=self.alpha
        )

        # Initialize results with total energy (divide by 2 to avoid double counting)
        results = {"energy": 0.5 * pair_energies.sum()}

        if self.per_atom_energies:
            # Compute per-atom energy contributions
            atom_energies = torch.zeros(
                positions.shape[0], dtype=self.dtype, device=self.device
            )
            # Each atom gets half of the pair energy
            atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies)
            atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies)
            results["energies"] = atom_energies

        if self.compute_forces or self.compute_stress:
            # Calculate pair forces
            pair_forces = soft_sphere_pair_force(
                distances, sigma=self.sigma, epsilon=self.epsilon, alpha=self.alpha
            )

            # Project scalar forces onto displacement vectors
            force_vectors = (pair_forces / distances)[:, None] * dr_vec

            if self.compute_forces:
                # Compute atomic forces by accumulating pair contributions
                forces = torch.zeros_like(positions)
                # Add force contributions (f_ij on j, -f_ij on i)
                forces.index_add_(0, mapping[0], force_vectors)
                forces.index_add_(0, mapping[1], -force_vectors)
                results["forces"] = forces

            if self.compute_stress and cell is not None:
                # Compute stress tensor using virial formula
                stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
                volume = torch.abs(torch.linalg.det(cell))

                results["stress"] = -stress_per_pair.sum(dim=0) / volume

                if self.per_atom_stresses:
                    # Compute per-atom stress contributions
                    atom_stresses = torch.zeros(
                        (positions.shape[0], 3, 3), dtype=self.dtype, device=self.device
                    )
                    atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair)
                    atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair)
                    results["stresses"] = atom_stresses / volume

        return results
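
The pair-to-atom accumulation in `unbatched_forward` can be tried in isolation. The sketch below is a toy and not the library's code path: the three-atom geometry, the `mapping` tensor, and the scalar pair-force magnitudes are made-up values standing in for the neighbor list and `soft_sphere_pair_force`, and each pair is listed only once here. It shows how `index_add_` scatters equal and opposite contributions so that the net force vanishes, and how the per-pair outer product is summed into a 3x3 virial term.

```python
import torch

# Hypothetical toy system: three atoms on a line, two interacting pairs.
positions = torch.tensor([[0.0, 0.0, 0.0], [0.8, 0.0, 0.0], [1.5, 0.0, 0.0]])
mapping = torch.tensor([[0, 1], [1, 2]])  # row 0: first atom, row 1: second atom

# Displacement vectors and distances for each listed pair.
dr_vec = positions[mapping[1]] - positions[mapping[0]]
distances = dr_vec.norm(dim=1)

# Made-up scalar pair-force magnitudes (assumption, not the real potential).
pair_forces = torch.tensor([2.0, 3.0])

# Project the scalars onto the unit displacement vectors.
force_vectors = (pair_forces / distances)[:, None] * dr_vec

# Scatter +f onto the first atom of each pair and -f onto the second.
forces = torch.zeros_like(positions)
forces.index_add_(0, mapping[0], force_vectors)
forces.index_add_(0, mapping[1], -force_vectors)

# Outer product r_i * f_j per pair, summed into a single 3x3 tensor.
virial = torch.einsum("...i,...j->...ij", dr_vec, force_vectors).sum(dim=0)

print(forces.sum(dim=0))  # net force is zero up to rounding
```

Because every pair adds `+f` to one atom and `-f` to the other, the total force sums to zero regardless of the magnitudes chosen.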

    def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
        """Compute soft sphere potential energies, forces, and stresses for a system.

        Main entry point for soft sphere potential calculations that handles batched
        states by dispatching each batch to the unbatched implementation and
        combining results.

        Args:
            state (SimState | StateDict): Input state containing atomic positions,
                cell vectors, and other system information. Can be a SimState object
                or a dictionary with the same keys.

        Returns:
            dict[str, torch.Tensor]: Dictionary of computed properties:
                - "energy": Potential energy with shape [n_batches]
                - "forces": Atomic forces with shape [n_atoms, 3]
                  (if compute_forces=True)
                - "stress": Stress tensor with shape [n_batches, 3, 3]
                  (if compute_stress=True)
                - May include additional outputs based on configuration

        Raises:
            ValueError: If batch cannot be inferred for multi-cell systems.

        Examples:
            ```py
            # Compute properties for a simulation state
            model = SoftSphereModel(compute_forces=True)
            results = model(sim_state)

            energy = results["energy"]  # Shape: [n_batches]
            forces = results["forces"]  # Shape: [n_atoms, 3]
            ```
        """
        if isinstance(state, dict):
            state = SimState(**state, masses=torch.ones_like(state["positions"]))

        # Handle batch indices if not provided
        if state.batch is None and state.cell.shape[0] > 1:
            raise ValueError("Batch can only be inferred for batch size 1.")

        outputs = [self.unbatched_forward(state[i]) for i in range(state.n_batches)]
        properties = outputs[0]

        # Combine results
        results = {}
        for key in ("stress", "energy"):
            if key in properties:
                results[key] = torch.stack([out[key] for out in outputs])
        for key in ("forces", "energies", "stresses"):
            if key in properties:
                results[key] = torch.cat([out[key] for out in outputs], dim=0)

        return results
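
The combining step in `forward` treats per-system and per-atom outputs differently: per-system values are stacked along a new batch axis, while per-atom values are concatenated along the atom axis, since systems may contain different numbers of atoms. A minimal sketch of that distinction, using made-up output dictionaries rather than real model results:

```python
import torch

# Hypothetical per-system outputs for two systems with 2 and 3 atoms.
outputs = [
    {"energy": torch.tensor(1.0), "forces": torch.zeros(2, 3)},
    {"energy": torch.tensor(2.5), "forces": torch.ones(3, 3)},
]

results = {}
# Per-system scalars gain a leading batch dimension.
results["energy"] = torch.stack([out["energy"] for out in outputs])
# Per-atom arrays are joined along the atom axis.
results["forces"] = torch.cat([out["forces"] for out in outputs], dim=0)

print(results["energy"].shape)  # torch.Size([2])
print(results["forces"].shape)  # torch.Size([5, 3])
```

Calling `torch.stack` on the two force arrays would raise an error here, because their leading dimensions (2 and 3) differ.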


class SoftSphereMultiModel(torch.nn.Module):
    """Calculator for systems with multiple particle types.

    Extends the basic soft sphere model to support multiple particle types with
    different interaction parameters for each pair of particle types. This enables
    simulation of heterogeneous systems like mixtures, composites, or biomolecular
    systems with different interaction strengths between different components.

    This model maintains matrices of interaction parameters (sigma, epsilon, alpha)
    where each element [i,j] represents the parameter for interactions between
    particle types i and j.

    Attributes:
        species (torch.Tensor): Particle type indices for each particle in the system.
        sigma_matrix (torch.Tensor): Matrix of distance parameters for each pair of
            types. Shape: [n_types, n_types].
        epsilon_matrix (torch.Tensor): Matrix of energy scale parameters for each
            pair. Shape: [n_types, n_types].
        alpha_matrix (torch.Tensor): Matrix of exponents for each pair of types.
            Shape: [n_types, n_types].
        cutoff (torch.Tensor): Maximum interaction distance.
        compute_forces (bool): Whether to compute forces.
        compute_stress (bool): Whether to compute stress tensor.
        per_atom_energies (bool): Whether to compute per-atom energy decomposition.
        per_atom_stresses (bool): Whether to compute per-atom stress decomposition.
        use_neighbor_list (bool): Whether to use neighbor list optimization.
        pbc (bool): Whether to use periodic boundary conditions.
        device (torch.device): Computation device (CPU/GPU).
        dtype (torch.dtype): Data type for tensor calculations.

    Examples:
        ```py
        # Create a binary mixture with different interaction parameters
        # Define interaction matrices (size 2x2 for binary system)
        sigma_matrix = torch.tensor(
            [
                [1.0, 0.8],  # Type 0-0 and 0-1 interactions
                [0.8, 0.6],  # Type 1-0 and 1-1 interactions
            ]
        )

        epsilon_matrix = torch.tensor(
            [
                [1.0, 0.5],  # Type 0-0 and 0-1 interactions
                [0.5, 2.0],  # Type 1-0 and 1-1 interactions
            ]
        )

        # Particle type assignments (0 or 1 for each particle)
        species = torch.tensor([0, 0, 1, 1, 0, 1])

        # Create the model
        model = SoftSphereMultiModel(
            species=species,
            sigma_matrix=sigma_matrix,
            epsilon_matrix=epsilon_matrix,
            compute_forces=True,
        )

        # Compute properties
        results = model(simulation_state)
        ```
    """

    def __init__(
        self,
        species: torch.Tensor | None = None,
        sigma_matrix: torch.Tensor | None = None,
        epsilon_matrix: torch.Tensor | None = None,
        alpha_matrix: torch.Tensor | None = None,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
        *,  # Force keyword-only arguments
        pbc: bool = True,
        compute_forces: bool = True,
        compute_stress: bool = False,
        per_atom_energies: bool = False,
        per_atom_stresses: bool = False,
        use_neighbor_list: bool = True,
        cutoff: float | None = None,
    ) -> None:
        """Initialize a soft sphere model for multi-component systems.

        Creates a model for systems with multiple particle types, each with
        potentially different interaction parameters.

        Args:
            species (torch.Tensor | None): Particle type indices, shape
                [n_particles]. Each value should be an integer in range
                [0, n_types-1]. If None, assumes all particles are the same type
                (0). Defaults to None.
            sigma_matrix (torch.Tensor | None): Matrix of distance parameters for
                each pair of types. Shape [n_types, n_types]. If None, uses default
                value 1.0 for all pairs. Defaults to None.
            epsilon_matrix (torch.Tensor | None): Matrix of energy scale parameters
                for each pair of types. Shape [n_types, n_types]. If None, uses
                default value 1.0 for all pairs. Defaults to None.
            alpha_matrix (torch.Tensor | None): Matrix of exponents for each pair.
                Shape [n_types, n_types]. If None, uses default value 2.0 for all
                pairs. Defaults to None.
            device (torch.device | None): Device for computations. If None, uses
                CPU. Defaults to None.
            dtype (torch.dtype): Data type for calculations. Defaults to
                torch.float32.
            pbc (bool): Whether to use periodic boundary conditions. Defaults to
                True.
            compute_forces (bool): Whether to compute forces. Defaults to True.
            compute_stress (bool): Whether to compute stress tensor. Defaults to
                False.
            per_atom_energies (bool): Whether to compute per-atom energy
                decomposition. Defaults to False.
            per_atom_stresses (bool): Whether to compute per-atom stress
                decomposition. Defaults to False.
            use_neighbor_list (bool): Whether to use a neighbor list for
                optimization. Defaults to True.
            cutoff (float | None): Cutoff distance for interactions. If None, uses
                the maximum value from sigma_matrix. Defaults to None.

        Raises:
            ValueError: If a provided parameter matrix does not have shape
                [n_types, n_types].

        Examples:
            ```py
            # Binary polymer mixture with different interactions
            # Polymer A (type 0): larger, softer particles
            # Polymer B (type 1): smaller, harder particles

            # Create species assignment (100 particles total)
            species = torch.cat(
                [
                    torch.zeros(50, dtype=torch.long),  # 50 particles of type 0
                    torch.ones(50, dtype=torch.long),  # 50 particles of type 1
                ]
            )

            # Interaction matrices
            sigma = torch.tensor(
                [
                    [1.2, 1.0],  # A-A and A-B interactions
                    [1.0, 0.8],  # B-A and B-B interactions
                ]
            )

            epsilon = torch.tensor(
                [
                    [1.0, 1.5],  # A-A and A-B interactions
                    [1.5, 2.0],  # B-A and B-B interactions
                ]
            )

            # Create model from the explicit pair matrices
            model = SoftSphereMultiModel(
                species=species,
                sigma_matrix=sigma,
                epsilon_matrix=epsilon,
                compute_forces=True,
            )
            ```

        Notes:
            The interaction matrices must be symmetric for physical consistency
            (e.g., interaction of type 0 with type 1 should be the same as type 1
            with type 0).
        """
        super().__init__()
        self.device = device or torch.device("cpu")
        self.dtype = dtype
        self.pbc = pbc
        self.compute_forces = compute_forces
        self.compute_stress = compute_stress
        self.per_atom_energies = per_atom_energies
        self.per_atom_stresses = per_atom_stresses
        self.use_neighbor_list = use_neighbor_list

        # Store species list and determine number of unique species
        self.species = species
        # With no species given, treat the system as a single particle type,
        # as documented above (torch.unique(None) would otherwise raise)
        n_species = 1 if species is None else len(torch.unique(species))

        # Initialize parameter matrices with defaults if not provided
        default_sigma = DEFAULT_SIGMA.to(device=self.device, dtype=self.dtype)
        default_epsilon = DEFAULT_EPSILON.to(device=self.device, dtype=self.dtype)
        default_alpha = DEFAULT_ALPHA.to(device=self.device, dtype=self.dtype)

        # Validate matrix shapes match number of species
        if sigma_matrix is not None and sigma_matrix.shape != (n_species, n_species):
            raise ValueError(f"sigma_matrix must have shape ({n_species}, {n_species})")
        if epsilon_matrix is not None and epsilon_matrix.shape != (
            n_species,
            n_species,
        ):
            raise ValueError(f"epsilon_matrix must have shape ({n_species}, {n_species})")
        if alpha_matrix is not None and alpha_matrix.shape != (n_species, n_species):
            raise ValueError(f"alpha_matrix must have shape ({n_species}, {n_species})")

        # Create parameter matrices, using defaults if not provided
        self.sigma_matrix = (
            sigma_matrix
            if sigma_matrix is not None
            else default_sigma
            * torch.ones((n_species, n_species), dtype=dtype, device=device)
        )
        self.epsilon_matrix = (
            epsilon_matrix
            if epsilon_matrix is not None
            else default_epsilon
            * torch.ones((n_species, n_species), dtype=dtype, device=device)
        )
        self.alpha_matrix = (
            alpha_matrix
            if alpha_matrix is not None
            else default_alpha
            * torch.ones((n_species, n_species), dtype=dtype, device=device)
        )

        # Ensure parameter matrices are symmetric (required for energy conservation)
        assert torch.allclose(self.sigma_matrix, self.sigma_matrix.T)
        assert torch.allclose(self.epsilon_matrix, self.epsilon_matrix.T)
        assert torch.allclose(self.alpha_matrix, self.alpha_matrix.T)

        # Set interaction cutoff distance
        self.cutoff = torch.tensor(
            cutoff or float(self.sigma_matrix.max()), dtype=dtype, device=device
        )
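
The constructor rejects asymmetric parameter matrices with bare `assert` statements. When building matrices by hand it can help to check them beforehand; the sketch below uses a hypothetical helper, `is_symmetric`, which is not part of torch_sim, and applies the same `torch.allclose` comparison against the transpose.

```python
import torch


def is_symmetric(matrix: torch.Tensor) -> bool:
    """Return True if a square 2D matrix equals its transpose (hypothetical helper)."""
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        return False
    return bool(torch.allclose(matrix, matrix.T))


# Symmetric: the 0-1 and 1-0 entries agree.
sigma_ok = torch.tensor([[1.2, 1.0], [1.0, 0.8]])
# Asymmetric: the 0-1 and 1-0 entries differ, so the model's assert would fail.
sigma_bad = torch.tensor([[1.2, 1.0], [0.9, 0.8]])

print(is_symmetric(sigma_ok))   # True
print(is_symmetric(sigma_bad))  # False
```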

    def unbatched_forward(  # noqa: PLR0915
        self,
        state: SimState,
        species: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Compute energies and forces for a single unbatched system with multiple
        species.

        Internal implementation that processes a single, non-batched simulation
        state. This method handles all pair interactions between particles of
        different types using the appropriate interaction parameters from the
        parameter matrices.

        Args:
            state (SimState): Single, non-batched simulation state containing atomic
                positions, cell vectors, and other system information.
            species (torch.Tensor | None): Optional species indices to override the
                ones provided during initialization. Shape: [n_particles]. If None,
                uses the species defined at initialization. Defaults to None.

        Returns:
            dict[str, torch.Tensor]: Dictionary of computed properties:
                - "energy": Total potential energy (scalar)
                - "forces": Atomic forces with shape [n_atoms, 3]
                  (if compute_forces=True)
                - "stress": Stress tensor with shape [3, 3]
                  (if compute_stress=True)
                - "energies": Per-atom energies with shape [n_atoms]
                  (if per_atom_energies=True)
                - "stresses": Per-atom stresses with shape [n_atoms, 3, 3]
                  (if per_atom_stresses=True)

        Notes:
            This method supports both neighbor list optimization and full pairwise
            calculations based on the use_neighbor_list parameter. For each pair of
            particles, it looks up the appropriate parameters based on the species
            of the two particles.
        """
        # Convert inputs to proper device/dtype and handle species
        if not isinstance(state, SimState):
            state = SimState(**state)

        if species is not None:
            species = species.to(device=self.device, dtype=torch.long)
        else:
            species = self.species

        positions = state.positions
        cell = state.row_vector_cell
        cell = cell.squeeze()

        if species is None:
            # No species given anywhere: treat every particle as type 0
            species = torch.zeros(
                positions.shape[0], dtype=torch.long, device=self.device
            )
        species_idx = species

        # Compute neighbor list or full distance matrix
        if self.use_neighbor_list:
            # Get neighbor list for efficient computation
            mapping, shifts = vesin_nl_ts(
                positions=positions,
                cell=cell,
                pbc=self.pbc,
                cutoff=self.cutoff,
                sort_id=False,
            )
            # Get displacements between neighbor pairs
            dr_vec, distances = get_pair_displacements(
                positions=positions,
                cell=cell,
                pbc=self.pbc,
                pairs=mapping,
                shifts=shifts,
            )
        else:
            # Direct N^2 computation of all pairs
            dr_vec, distances = get_pair_displacements(
                positions=positions,
                cell=cell,
                pbc=self.pbc,
            )
            # Remove self-interactions and apply cutoff
            mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
            distances = distances.masked_fill(mask, float("inf"))
            mask = distances < self.cutoff

            # Get valid pairs and their displacements
            i, j = torch.where(mask)
            mapping = torch.stack([j, i])
            dr_vec = dr_vec[mask]
            distances = distances[mask]

        # Look up species-specific parameters for each interacting pair
        pair_species_1 = species_idx[mapping[0]]  # Species of first atom in pair
        pair_species_2 = species_idx[mapping[1]]  # Species of second atom in pair

        # Get interaction parameters from parameter matrices
        pair_sigmas = self.sigma_matrix[pair_species_1, pair_species_2]
        pair_epsilons = self.epsilon_matrix[pair_species_1, pair_species_2]
        pair_alphas = self.alpha_matrix[pair_species_1, pair_species_2]

        # Calculate pair energies using species-specific parameters
        pair_energies = soft_sphere_pair(
            distances, sigma=pair_sigmas, epsilon=pair_epsilons, alpha=pair_alphas
        )

        # Initialize results with total energy (divide by 2 to avoid double counting)
        results = {"energy": 0.5 * pair_energies.sum()}

        if self.per_atom_energies:
            # Compute per-atom energy contributions
            atom_energies = torch.zeros(
                positions.shape[0], dtype=self.dtype, device=self.device
            )
            # Each atom gets half of the pair energy
            atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies)
            atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies)
            results["energies"] = atom_energies

        if self.compute_forces or self.compute_stress:
            # Calculate pair forces
            pair_forces = soft_sphere_pair_force(
                distances, sigma=pair_sigmas, epsilon=pair_epsilons, alpha=pair_alphas
            )

            # Project scalar forces onto displacement vectors
            force_vectors = (pair_forces / distances)[:, None] * dr_vec

            if self.compute_forces:
                # Compute atomic forces by accumulating pair contributions
                forces = torch.zeros_like(positions)
                # Add force contributions (f_ij on j, -f_ij on i)
                forces.index_add_(0, mapping[0], force_vectors)
                forces.index_add_(0, mapping[1], -force_vectors)
                results["forces"] = forces

            if self.compute_stress and cell is not None:
                # Compute stress tensor using virial formula
                stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
                volume = torch.abs(torch.linalg.det(cell))

                results["stress"] = -stress_per_pair.sum(dim=0) / volume

                if self.per_atom_stresses:
                    # Compute per-atom stress contributions
                    atom_stresses = torch.zeros(
                        (positions.shape[0], 3, 3), dtype=self.dtype, device=self.device
                    )
                    atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair)
                    atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair)
                    results["stresses"] = atom_stresses / volume

        return results
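
The species-dependent parameter lookup relies on PyTorch advanced indexing: indexing a `[n_types, n_types]` matrix with two equal-length index tensors returns one entry per pair. A minimal sketch with made-up values (the three-particle `species` tensor and the `mapping` pairs are illustrative assumptions, not output of a real neighbor list):

```python
import torch

# Hypothetical system: particles 0 and 1 are type 0, particle 2 is type 1.
species = torch.tensor([0, 0, 1])

# Symmetric 2x2 table of pair diameters, indexed by (type_a, type_b).
sigma_matrix = torch.tensor([[1.0, 0.8], [0.8, 0.6]])

# Pairs (0, 1), (0, 2), (1, 2): row 0 holds first atoms, row 1 second atoms.
mapping = torch.tensor([[0, 0, 1], [1, 2, 2]])

pair_species_1 = species[mapping[0]]  # types of the first atom in each pair
pair_species_2 = species[mapping[1]]  # types of the second atom in each pair

# One sigma per pair, picked elementwise from the table.
pair_sigmas = sigma_matrix[pair_species_1, pair_species_2]
print(pair_sigmas)  # tensor([1.0000, 0.8000, 0.8000])
```

The same indexing pattern applies to `epsilon_matrix` and `alpha_matrix`, so each pair carries its own parameter triple into the potential.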

    def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
        """Compute soft sphere potential properties for multi-component systems.

        Main entry point for multi-species soft sphere calculations that handles
        batched states by dispatching each batch to the unbatched implementation and
        combining results.

        Args:
            state (SimState | StateDict): Input state containing atomic positions,
                cell vectors, and other system information. Can be a SimState object
                or a dictionary with the same keys.

        Returns:
            dict[str, torch.Tensor]: Dictionary of computed properties:
                - "energy": Potential energy with shape [n_batches]
                - "forces": Atomic forces with shape [n_atoms, 3]
                  (if compute_forces=True)
                - "stress": Stress tensor with shape [n_batches, 3, 3]
                  (if compute_stress=True)
                - May include additional outputs based on configuration

        Raises:
            ValueError: If batch cannot be inferred for multi-cell systems, if the
                state's PBC setting differs from the model's, or if species
                information is missing.

        Examples:
            ```py
            # Create model for binary mixture
            model = SoftSphereMultiModel(
                species=particle_types,
                sigma_matrix=distance_matrix,
                epsilon_matrix=strength_matrix,
                compute_forces=True,
            )

            # Calculate properties
            results = model(simulation_state)
            energy = results["energy"]
            forces = results["forces"]
            ```

        Notes:
            This method requires species information either provided during
            initialization or included in the state object's metadata.
        """
        if not isinstance(state, SimState):
            state = SimState(
                **state, pbc=self.pbc, masses=torch.ones_like(state["positions"])
            )
        elif state.pbc != self.pbc:
            raise ValueError("PBC mismatch between model and state")

        # Handle batch indices if not provided
        if state.batch is None and state.cell.shape[0] > 1:
            raise ValueError("Batch can only be inferred for batch size 1.")

        outputs = [self.unbatched_forward(state[i]) for i in range(state.n_batches)]
        properties = outputs[0]

        # Combine results: stack per-system values, concatenate per-atom values
        results = {}
        for key in ("stress", "energy"):
            if key in properties:
                results[key] = torch.stack([out[key] for out in outputs])
        for key in ("forces", "energies", "stresses"):
            if key in properties:
                results[key] = torch.cat([out[key] for out in outputs], dim=0)

        return results