Source code for torch_sim.monte_carlo

"""Propagators for Monte Carlo simulations.

This module provides functionality for performing Monte Carlo simulations,
particularly focused on swap Monte Carlo for atomic systems. It includes
implementations of the Metropolis criterion, swap generation, and utility
functions for handling permutations in batched systems.

The `swap_monte_carlo` function can be used with `integrate`, but if
a trajectory is being reported, the `TorchSimTrajectory.write_state` method
must be called with `variable_masses=True`.
"""

from collections.abc import Callable
from dataclasses import dataclass

import torch

from torch_sim.state import SimState


@dataclass
class SwapMCState(SimState):
    """Simulation state with the extra bookkeeping needed by swap Monte Carlo.

    Extends ``SimState`` with the current energy of every system in the batch
    and the permutation of atom indices produced by the most recent step, so
    the last move can be inspected or undone.

    Attributes:
        energy (torch.Tensor): Energy of the system with shape [batch_size]
        last_permutation (torch.Tensor): Last permutation applied to the
            system, with shape [n_atoms], tracking the moves made for analysis
            or reversal
    """

    # Per-system energy, shape [batch_size].
    energy: torch.Tensor
    # Atom-index permutation applied by the most recent step, shape [n_atoms].
    last_permutation: torch.Tensor
def generate_swaps(
    state: SimState, generator: torch.Generator | None = None
) -> torch.Tensor:
    """Generate atom swaps for a given batched system.

    For every system in the batch, one atom is drawn uniformly at random. A
    second atom is then drawn uniformly among the atoms of the same system
    whose atomic number differs from the first one. Swaps therefore only
    occur between atoms with different atomic numbers.

    Args:
        state (SimState): The simulation state. Only ``state.batch`` and
            ``state.atomic_numbers`` are read. Atoms are assumed to be stored
            contiguously and ordered by batch index.
        generator (torch.Generator | None, optional): Random number generator
            for reproducibility. Defaults to None.

    Returns:
        torch.Tensor: A tensor of proposed swaps with shape [n_batches, 2],
            where each row contains the global indices of the two atoms to be
            swapped.

    Raises:
        ValueError: If some system has no atom whose atomic number differs
            from the randomly selected first atom (e.g. a single-species
            system), so no swap partner exists.
    """
    batch = state.batch
    atomic_numbers = state.atomic_numbers
    device = batch.device

    batch_lengths = batch.bincount()
    n_batches = len(batch_lengths)
    max_length = int(batch_lengths.max().item())

    # Global index of each system's first atom: an *exclusive* cumulative sum.
    # (Subtracting batch_lengths[0] instead is only correct when every system
    # has the same number of atoms.)
    batch_starts = batch_lengths.cumsum(dim=0) - batch_lengths

    # Padded [n_batches, max_length] layout: column j of row b is the j-th atom
    # of system b; columns beyond a system's length are padding.
    local_index = torch.arange(max_length, device=device).expand(
        n_batches, max_length
    )
    is_atom = local_index < batch_lengths.unsqueeze(1)
    weights = is_atom.float()

    first_index = torch.multinomial(
        weights, 1, replacement=False, generator=generator
    )

    # Scatter atomic numbers into the padded layout. Boolean-mask assignment
    # fills in row-major order, which matches atoms sorted by batch index.
    # Padding values are irrelevant because padding already has weight 0.
    padded_numbers = atomic_numbers.new_zeros((n_batches, max_length))
    padded_numbers[is_atom] = atomic_numbers
    first_types = padded_numbers.gather(1, first_index)  # [n_batches, 1]

    # A swap partner must have a different atomic number than the first atom.
    weights = weights.masked_fill(padded_numbers == first_types, 0.0)
    if bool((weights.sum(dim=1) == 0).any()):
        raise ValueError(
            "Cannot generate swaps: a system has no atom with an atomic number "
            "different from the selected atom"
        )

    second_index = torch.multinomial(
        weights, 1, replacement=False, generator=generator
    )

    local_swaps = torch.cat([first_index, second_index], dim=1)
    # Convert per-system local indices into global atom indices.
    return local_swaps + batch_starts.unsqueeze(1)
def swaps_to_permutation(swaps: torch.Tensor, n_atoms: int) -> torch.Tensor:
    """Build the index permutation that realizes a set of pairwise swaps.

    Starting from the identity, each pair ``(i, j)`` in ``swaps`` makes
    position ``i`` take atom ``j`` and position ``j`` take atom ``i``.

    Args:
        swaps (torch.Tensor): Tensor of shape [n_swaps, 2] containing pairs of
            indices to swap
        n_atoms (int): Total number of atoms in the system

    Returns:
        torch.Tensor: Permutation tensor of shape [n_atoms] where entry ``k``
            is the index of the atom that should be moved to position ``k``;
            indices not mentioned in ``swaps`` map to themselves.
    """
    left, right = swaps[:, 0], swaps[:, 1]
    perm = torch.arange(n_atoms, device=swaps.device)
    # The second column is written last, so it wins if an index appears in
    # both columns.
    perm[left] = right
    perm[right] = left
    return perm
def validate_permutation(permutation: torch.Tensor, batch: torch.Tensor) -> None:
    """Check that a permutation never moves an atom into another batch.

    A swap between atoms of different batches would mix separate systems and
    produce a physically invalid configuration.

    Args:
        permutation (torch.Tensor): Permutation tensor of shape [n_atoms]
        batch (torch.Tensor): Batch assignments for each atom of shape [n_atoms]

    Raises:
        ValueError: If any swaps are between atoms in different batches
    """
    crosses_batches = (batch[permutation] != batch).any()
    if crosses_batches:
        raise ValueError("Swaps must be between atoms in the same batch")
[docs] def metropolis_criterion( energy_new: torch.Tensor, energy_old: torch.Tensor, kT: float, generator: torch.Generator | None = None, ) -> torch.Tensor: """Apply the Metropolis acceptance criterion for Monte Carlo moves. Determines whether proposed moves should be accepted or rejected based on the energy difference and system temperature, following the Boltzmann distribution. Args: energy_new (torch.Tensor): New energy after proposed move of shape [batch_size] energy_old (torch.Tensor): Old energy before proposed move of shape [batch_size] kT (float): Temperature of the system in energy units generator (torch.Generator | None, optional): Random number generator for reproducibility. Defaults to None. Returns: torch.Tensor: Boolean tensor of shape [batch_size] indicating acceptance (True) or rejection (False) for each move Notes: The acceptance probability follows min(1, exp(-ΔE/kT)) according to the standard Metropolis algorithm. """ delta_e = energy_new - energy_old # Calculate acceptance probability: min(1, exp(-ΔE/kT)) p_acceptance = torch.exp(-delta_e / kT) # Generate random numbers between 0 and 1 using the generator random_values = torch.rand( p_acceptance.shape, generator=generator, device=p_acceptance.device ) # Accept if random value < acceptance probability return random_values < p_acceptance
def swap_monte_carlo(
    *,
    model: torch.nn.Module,
    kT: float,
    seed: int | None = None,
) -> tuple[
    Callable[[SimState], SwapMCState],
    Callable[[SwapMCState, float, torch.Generator | None], SwapMCState],
]:
    """Initialize a swap Monte Carlo simulation for atomic structure optimization.

    Creates and returns functions for initializing the Monte Carlo state and
    performing Monte Carlo steps. The simulation uses the Metropolis criterion
    to accept or reject proposed swaps based on energy differences.

    Make sure that if the trajectory is being reported, the
    `TorchSimTrajectory.write_state` method is called with `variable_masses=True`.

    Args:
        model (torch.nn.Module): Energy model that takes a SimState and returns
            a dict containing 'energy' as a key
        kT (float): Temperature of the system in energy units
        seed (int | None, optional): Seed for the random number generator.
            Defaults to None. When given, the generator is created on
            ``model.device`` if the model defines it, otherwise on the device of
            the model's first parameter or buffer, otherwise on the CPU.

    Returns:
        tuple: A tuple containing:
            - init_function (Callable): Function to initialize a SwapMCState
              from a SimState
            - step_function (Callable): Function to perform a single Monte
              Carlo step

    Examples:
        >>> init_fn, step_fn = swap_monte_carlo(model=energy_model, kT=0.1, seed=42)
        >>> mc_state = init_fn(initial_state)
        >>> for _ in range(100):
        ...     mc_state = step_fn(mc_state)
    """
    generator: torch.Generator | None = None
    if seed is not None:
        # A plain torch.nn.Module has no ``device`` attribute, so fall back to
        # the device of the model's first parameter/buffer, then to the CPU.
        device = getattr(model, "device", None)
        if device is None and isinstance(model, torch.nn.Module):
            first_tensor = next(model.parameters(), None)
            if first_tensor is None:
                first_tensor = next(model.buffers(), None)
            if first_tensor is not None:
                device = first_tensor.device
        generator = torch.Generator(device=device if device is not None else "cpu")
        generator.manual_seed(seed)

    def init_swap_mc_state(state: SimState) -> SwapMCState:
        """Evaluate ``model`` on ``state`` and wrap it in a SwapMCState."""
        model_output = model(state)
        return SwapMCState(
            positions=state.positions,
            masses=state.masses,
            cell=state.cell,
            pbc=state.pbc,
            atomic_numbers=state.atomic_numbers,
            batch=state.batch,
            energy=model_output["energy"],
            # No move has been made yet, so the last permutation is the identity.
            last_permutation=torch.arange(state.n_atoms, device=state.device),
        )

    def swap_monte_carlo_step(
        state: SwapMCState,
        kT: float = kT,
        generator: torch.Generator | None = generator,
    ) -> SwapMCState:
        """Perform a single swap Monte Carlo step.

        Proposes one swap per system, evaluates the energy of the proposed
        configuration, and uses the Metropolis criterion to decide whether to
        accept each move. Rejected moves are reversed. ``state`` is modified in
        place and also returned.

        Args:
            state (SwapMCState): The current Monte Carlo state
            kT (float, optional): Temperature parameter in energy units.
                Defaults to the value specified in the outer function.
            generator (torch.Generator | None, optional): Random number
                generator. Defaults to the generator created from ``seed`` in
                the outer function (None when no seed was given).

        Returns:
            SwapMCState: Updated Monte Carlo state after applying the step

        Notes:
            The function handles batched systems and ensures that swaps only
            occur within the same batch.
        """
        swaps = generate_swaps(state, generator=generator)
        permutation = swaps_to_permutation(swaps, state.n_atoms)
        validate_permutation(permutation, state.batch)

        energies_old = state.energy.clone()
        # Advanced indexing already returns a new tensor; no clone needed.
        state.positions = state.positions[permutation]

        energies_new = model(state)["energy"]
        accepted = metropolis_criterion(
            energies_new, energies_old, kT, generator=generator
        )

        # A swap is its own inverse, so re-applying the rejected swaps restores
        # the original positions of those systems.
        revert = swaps_to_permutation(swaps[~accepted], state.n_atoms)
        state.positions = state.positions[revert]
        state.energy = torch.where(accepted, energies_new, energies_old)
        # Net permutation for this step: the proposal for accepted systems and
        # the identity for rejected ones.
        state.last_permutation = permutation[revert]
        return state

    return init_swap_mc_state, swap_monte_carlo_step