Source code for torch_sim.properties.correlations

"""Correlation function calculators for time series data.

This module provides efficient calculators for time correlation functions [1]_,
including both autocorrelation and cross-correlation functionality. It
leverages FFT-based methods [2]_ for performance and supports both CPU and
GPU execution through PyTorch.

The ``CorrelationCalculator`` class provides on-the-fly
correlation calculations during simulation runs, and a ``CircularBuffer``
utility class assists in data storage without frequent reallocations.

The ``VelocityAutoCorrelation`` class provides an interface for
computing the velocity autocorrelation function (VACF).

References:
    .. [1] D. Frenkel and B. Smit, "Understanding molecular simulation: From
       algorithms to applications", Academic Press, 2002.
    .. [2] `pwtools: Phonon DOS <https://elcorto.github.io/pwtools/written/background/phonon_dos.html>`_

"""

from collections.abc import Callable
from typing import Any

import torch

from torch_sim.state import SimState
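
# Illustrative note (not part of the module API): the FFT-based estimator used in
# ``CorrelationCalculator._compute_correlations`` below relies on the
# Wiener-Khinchin relation, i.e. the autocorrelation of a signal is the inverse
# FFT of its power spectrum. A minimal sketch for a 1D tensor ``x`` of length N,
# zero-padded to 2N so the circular correlation reduces to the linear one:
#
#     x = x - x.mean()
#     spectrum = torch.fft.rfft(x, n=2 * len(x))
#     acf = torch.fft.irfft(torch.abs(spectrum) ** 2)[: len(x)]
#     acf = acf / acf[0]  # optional normalization so that acf[0] == 1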


class CircularBuffer:
    """Circular buffer for storing time series data.

    Provides a fixed-size circular buffer optimized for storing and retrieving
    time series data with minimal memory allocation.

    Attributes:
        size: Maximum number of elements to store
        buffer: Storage for the data
        head: Current write position
        count: Number of elements currently stored
        device: Device where the buffer is stored
    """

    def __init__(self, size: int, device: torch.device | None = None) -> None:
        """Initialize a circular buffer.

        Args:
            size: Maximum number of elements to store
            device: Device for tensor storage (CPU or GPU)
        """
        self.size = size
        self.buffer: torch.Tensor | None = None
        self.head = 0
        self.count = 0
        self.device = device

    def append(self, value: torch.Tensor) -> None:
        """Append a new value to the buffer.

        Args:
            value: New tensor to store
        """
        if self.buffer is None:
            # Allocate storage on first use, matching the shape of the first value
            shape = (self.size, *value.shape)
            self.buffer = torch.zeros(shape, device=self.device, dtype=value.dtype)

        if self.buffer is not None:
            self.buffer[self.head] = value
            self.head = (self.head + 1) % self.size
            self.count = min(self.count + 1, self.size)

    def get_array(self) -> torch.Tensor:
        """Get the current buffer contents as a tensor.

        Returns:
            Tensor containing the buffered data in chronological order
        """
        if self.count == 0 or self.buffer is None:
            return torch.empty(0, device=self.device)

        if self.count < self.size:
            # Buffer not yet full: return only the filled portion
            return self.buffer[: self.count]

        # Buffer is full: unwrap into chronological order, avoiding an
        # unnecessary copy when the data is already contiguous
        if self.head == 0:
            return self.buffer
        return torch.cat([self.buffer[self.head :], self.buffer[: self.head]])

    @property
    def is_full(self) -> bool:
        """Check if the buffer is full."""
        return self.count == self.size

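
# Usage sketch for ``CircularBuffer`` (illustrative only; the tensor shape below
# is an arbitrary example, not a requirement of the class):
#
#     buf = CircularBuffer(size=3)
#     for step in range(5):
#         buf.append(torch.full((4, 3), float(step)))  # e.g. per-atom vectors
#     buf.get_array()  # shape (3, 4, 3): the last three samples, oldest first
#     buf.is_full      # True
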
class CorrelationCalculator:
    """Efficient on-the-fly correlation function calculator.

    Manages the calculation of time correlation functions during a simulation,
    with support for both autocorrelation and cross-correlation of arbitrary
    properties. It maintains a sliding window of historical data and performs
    efficient updates.

    Attributes:
        window_size: Number of steps to keep in memory
        properties: Map of property names to their calculators
        buffers: Circular buffers for storing historical data
        correlations: Current correlation results
        device: Device where calculations are performed

    Example:
        Computing correlation functions in a loop::

            corr_calc = CorrelationCalculator(
                window_size=100,
                properties={"velocity": lambda state: state.velocities},
                device=device,
            )

            for step in range(n_steps):
                state = integrator.step(state)

                # Call update at the desired sampling frequency
                if step % 10 == 0:  # Sample every 10 steps
                    corr_calc.update(state)

                # Periodically retrieve correlation functions
                if step % 1000 == 0:
                    acfs = corr_calc.get_auto_correlations()
                    # Process or save acfs...
    """

    def __init__(
        self,
        *,
        window_size: int,
        properties: dict[str, Callable[[SimState], torch.Tensor]],
        device: torch.device,
        normalize: bool = True,
    ) -> None:
        """Initialize a correlation calculator.

        Args:
            window_size: Number of steps to keep in memory
            properties: Dictionary mapping names to functions that calculate
                properties from a SimState
            device: Device for tensor storage and computation
            normalize: Whether to normalize correlation functions by their
                zero-lag value, so that C(0) = 1
        """
        self.window_size = window_size
        self.properties = properties or {}
        self.device = device
        self.normalize = normalize
        self.buffers = {
            name: CircularBuffer(window_size, device=device)
            for name in self.properties
        }
        self.correlations: dict[str, torch.Tensor] = {}
        self.cross_correlations: dict[tuple[str, str], torch.Tensor] = {}

    def add_property(
        self, name: str, calculator: Callable[[SimState], torch.Tensor]
    ) -> None:
        """Track a new simulation property.

        Args:
            name: Name of the property
            calculator: Function that calculates the property from a SimState
        """
        if name in self.properties:
            raise ValueError(f"Property {name} already exists")

        self.properties[name] = calculator
        self.buffers[name] = CircularBuffer(self.window_size, device=self.device)

    def update(self, state: SimState) -> None:
        """Update correlation calculations with new state data.

        Args:
            state: Current simulation state
        """
        # Single pass: append the new value of every tracked property
        buffer_count = 0
        buffer_total = len(self.buffers)

        for name, calc in self.properties.items():
            value = calc(state)
            self.buffers[name].append(value)
            if self.buffers[name].count > 1:
                buffer_count += 1

        # Recompute correlations once every buffer has enough data
        if buffer_count == buffer_total and buffer_total > 0:
            self._compute_correlations()

    def _compute_correlations(self) -> None:  # noqa: C901, PLR0915
        """Compute correlation functions using FFT for efficiency."""
        # Autocorrelations
        for name, buf in self.buffers.items():
            data = buf.get_array()
            if len(data) == 0:
                continue

            original_shape = data.shape
            # Reshape to [time_steps, flattened_dim]
            if len(original_shape) > 1:
                data = data.reshape(original_shape[0], -1)

            n_dims = data.shape[1] if len(data.shape) > 1 else 1

            if n_dims > 1:
                # Pre-allocate the result and center the data once
                acf = torch.zeros(
                    (original_shape[0], n_dims), device=self.device, dtype=data.dtype
                )
                data_centered = data - data.mean(dim=0, keepdim=True)

                if data_centered.shape[1] <= 128:  # Batch threshold
                    # Transpose for batched FFT (dimensions become the batch axis)
                    data_batch = data_centered.T  # Shape: [n_dims, time_steps]

                    # Batched FFT, zero-padded to 2 * time_steps
                    n_fft = 2 * data_batch.shape[1]
                    fft_batch = torch.fft.rfft(data_batch, n=n_fft)
                    power_batch = torch.abs(fft_batch) ** 2
                    corr_batch = torch.fft.irfft(power_batch)[:, : data_batch.shape[1]]
                    corr_batch = corr_batch.T  # Shape: [time_steps, n_dims]

                    if self.normalize:
                        norms = corr_batch[0].clone()
                        mask = norms > 1e-10
                        if mask.any():
                            corr_batch[:, mask] = corr_batch[:, mask] / norms[
                                mask
                            ].unsqueeze(0)

                    acf = corr_batch.reshape(original_shape)
                else:
                    # Fallback for very high-dimensional data: one FFT per dimension
                    for i in range(n_dims):
                        dim_data = data_centered[:, i]

                        # Zero-pad to 2 * len to avoid circular-correlation artifacts
                        n_fft = 2 * len(dim_data)
                        fft = torch.fft.rfft(dim_data, n=n_fft)
                        power = torch.abs(fft) ** 2
                        corr = torch.fft.irfft(power)[: len(dim_data)]

                        if self.normalize and corr[0] > 1e-10:
                            corr = corr / corr[0]

                        acf[:, i] = corr

                    # Reshape back to match the input dimensions
                    acf = acf.reshape(original_shape)
            else:
                # Single-dimension case
                dim_data = data - data.mean()
                n_fft = 2 * len(dim_data)
                fft = torch.fft.rfft(dim_data, n=n_fft)
                power = torch.abs(fft) ** 2
                corr = torch.fft.irfft(power)[: len(dim_data)]

                if self.normalize and corr[0] > 1e-10:
                    corr = corr / corr[0]

                acf = corr

            self.correlations[name] = acf

        # Cross-correlations
        names = list(self.buffers.keys())
        for i, name1 in enumerate(names):
            for name2 in names[i + 1 :]:
                data1 = self.buffers[name1].get_array()
                data2 = self.buffers[name2].get_array()

                if len(data1) == 0 or len(data2) == 0:
                    continue

                min_len = min(len(data1), len(data2))
                data1 = data1[:min_len]
                data2 = data2[:min_len]

                # Multidimensional data: for now, simplify by averaging over
                # all non-time dimensions
                if len(data1.shape) > 1 or len(data2.shape) > 1:
                    if len(data1.shape) > 1:
                        non_time_dims = tuple(range(1, len(data1.shape)))
                        data1 = torch.mean(data1, dim=non_time_dims)
                    if len(data2.shape) > 1:
                        non_time_dims = tuple(range(1, len(data2.shape)))
                        data2 = torch.mean(data2, dim=non_time_dims)

                # Center the data
                data1 = data1 - data1.mean()
                data2 = data2 - data2.mean()

                n_fft = 2 * min_len
                fft1 = torch.fft.rfft(data1, n=n_fft)
                fft2 = torch.fft.rfft(data2, n=n_fft)
                ccf = torch.fft.irfft(fft1 * fft2.conj())[:min_len]

                if self.normalize and torch.abs(ccf[0]) > 1e-10:
                    ccf = ccf / ccf[0]

                self.cross_correlations[(name1, name2)] = ccf

    def get_auto_correlations(self) -> dict[str, torch.Tensor]:
        """Get autocorrelation results.

        Returns:
            Dictionary mapping property names to their correlation tensors
        """
        return self.correlations

    def get_cross_correlations(self) -> dict[tuple[str, str], torch.Tensor]:
        """Get cross-correlation results.

        Returns:
            Dictionary mapping pairs of property names to their
            cross-correlation tensors
        """
        return self.cross_correlations

    def reset(self) -> None:
        """Reset all buffers and correlations."""
        self.buffers = {
            name: CircularBuffer(self.window_size, device=self.device)
            for name in self.properties
        }
        self.correlations = {}
        self.cross_correlations = {}

    def to(self, device: torch.device) -> "CorrelationCalculator":
        """Move the calculator to the specified device.

        Args:
            device: Target device

        Returns:
            Self, for method chaining
        """
        # Skip if already on the target device
        if self.device == device:
            return self

        self.device = device

        new_buffers = {}
        for name, buf in self.buffers.items():
            new_buf = CircularBuffer(self.window_size, device=device)
            if buf.buffer is not None:
                data = buf.get_array().to(device)
                # Refill larger buffers in small batches
                if len(data) > 100:
                    batch_size = 20
                    for i in range(0, len(data), batch_size):
                        batch = data[i : min(i + batch_size, len(data))]
                        for j in range(len(batch)):
                            new_buf.append(batch[j])
                else:
                    for i in range(len(data)):
                        new_buf.append(data[i])
            new_buffers[name] = new_buf
        self.buffers = new_buffers

        # Move correlation results
        if self.correlations:
            self.correlations = {
                name: corr.to(device) for name, corr in self.correlations.items()
            }
        if self.cross_correlations:
            self.cross_correlations = {
                names: corr.to(device)
                for names, corr in self.cross_correlations.items()
            }
        return self

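
# Usage sketch for cross-correlations (illustrative only; it assumes the tracked
# state exposes ``velocities`` and ``forces`` attributes; substitute whichever
# per-state quantities your simulation actually provides):
#
#     corr_calc = CorrelationCalculator(
#         window_size=100,
#         properties={
#             "velocity": lambda s: s.velocities,
#             "force": lambda s: s.forces,
#         },
#         device=device,
#     )
#     for step in range(n_steps):
#         state = integrator.step(state)
#         corr_calc.update(state)
#     ccfs = corr_calc.get_cross_correlations()  # keyed by ("velocity", "force")
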
class VelocityAutoCorrelation:
    """Calculator for the velocity autocorrelation function (VACF).

    Computes the VACF by averaging over atoms and dimensions, with an optional
    running average across correlation windows.

    Using ``VelocityAutoCorrelation`` with
    :class:`~torch_sim.trajectory.TrajectoryReporter`::

        # Create VACF calculator
        vacf_calc = VelocityAutoCorrelation(
            window_size=100,
            device=device,
            use_running_average=True,
        )

        # Set up trajectory reporter
        reporter = TrajectoryReporter(
            "simulation_vacf.h5",
            state_frequency=100,
            prop_calculators={10: {"vacf": vacf_calc}},
        )
    """

    def __init__(
        self,
        *,
        window_size: int,
        device: torch.device,
        use_running_average: bool = True,
        normalize: bool = True,
    ) -> None:
        """Initialize the VACF calculator.

        Args:
            window_size: Number of steps in the correlation window
            device: Computation device
            use_running_average: Whether to compute a running average across windows
            normalize: Whether to normalize correlation functions by their
                zero-lag value
        """
        self.corr_calc = CorrelationCalculator(
            window_size=window_size,
            properties={"velocity": lambda s: s.velocities},
            device=device,
            normalize=normalize,
        )
        self.use_running_average = use_running_average
        self._window_count = 0
        self._avg = torch.zeros(window_size, device=device)

    def __call__(self, state: SimState, _: Any = None) -> torch.Tensor:
        """Update the VACF with a new state.

        Args:
            state: Current simulation state
            _: Unused model argument (required by the property calculator interface)

        Returns:
            Tensor containing the number of completed correlation windows; the
            accumulated VACF itself is available through the ``vacf`` property.
        """
        self.corr_calc.update(state)

        if self.corr_calc.buffers["velocity"].count == self.corr_calc.window_size:
            correlations = self.corr_calc.get_auto_correlations()
            # Average over atoms and spatial dimensions:
            # (window_size, n_atoms, n_dims) -> (window_size,)
            vacf = torch.mean(correlations["velocity"], dim=(1, 2))

            self._window_count += 1
            if self.use_running_average:
                factor = 1.0 / self._window_count
                self._avg += (vacf - self._avg) * factor
            else:
                self._avg = vacf

            self.corr_calc.reset()

        return torch.tensor([self._window_count], device=state.device)

    @property
    def vacf(self) -> torch.Tensor | None:
        """Current VACF result."""
        return self._avg
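
# Usage sketch for ``VelocityAutoCorrelation`` without a TrajectoryReporter
# (illustrative only; ``integrator``, ``state``, and ``n_steps`` stand in for
# your own MD loop):
#
#     vacf_calc = VelocityAutoCorrelation(window_size=100, device=device)
#     for step in range(n_steps):
#         state = integrator.step(state)
#         vacf_calc(state)   # returns the number of completed windows so far
#     vacf = vacf_calc.vacf  # running average over all completed windows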