torch_sim.autobatching

Automatic batching and GPU memory management.

This module provides utilities for efficient batch processing of simulation states by dynamically determining optimal batch sizes based on GPU memory constraints. It includes tools for memory usage estimation, batch size determination, and two complementary strategies for batching: binning and hot-swapping.

Example

Using BinningAutoBatcher with a model:

# Group states into memory-safe batches, binned by atom count.
batcher = BinningAutoBatcher(model, memory_scales_with="n_atoms")
batcher.load_states(states)
final_states = []
for batch in batcher:
    # evolve_batch is a user-supplied function that advances one batch.
    final_states.append(evolve_batch(batch))
# Results come back in the order the states were originally loaded.
final_states = batcher.restore_original_order(final_states)

Notes

Memory scaling estimates are approximate and may need tuning for specific model architectures and GPU configurations.

Functions

calculate_memory_scaler

Calculate a metric that estimates memory requirements for a state.
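The idea can be illustrated with a toy sketch. `ToyState` and `memory_scaler` below are hypothetical stand-ins, not torch_sim's actual implementation; the two metric names mirror the intuition that memory cost scales with atom count, and for neighbor-list-heavy models also with density:

```python
from dataclasses import dataclass

@dataclass
class ToyState:
    n_atoms: int
    volume: float  # simulation cell volume

def memory_scaler(state: ToyState, metric: str = "n_atoms") -> float:
    """Return a scalar proxy for how much memory a state will consume."""
    if metric == "n_atoms":
        return float(state.n_atoms)
    if metric == "n_atoms_x_density":
        # Denser cells produce larger neighbor lists, so weight by density.
        return state.n_atoms * (state.n_atoms / state.volume)
    raise ValueError(f"unknown metric: {metric}")
```

Summing this metric over a candidate batch and comparing against a measured budget is what makes batch-size decisions possible without trial-and-error OOMs.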

determine_max_batch_size

Determine maximum batch size that fits in GPU memory.
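One common way to find such a limit is to probe with geometrically growing batch sizes until an out-of-memory failure, then keep the last size that fit. The sketch below is a hypothetical illustration of that strategy, not torch_sim's exact search; `try_batch` stands in for a forward pass that catches CUDA OOM errors:

```python
def find_max_batch_size(try_batch, start: int = 1, max_size: int = 4096) -> int:
    """Grow the probe batch geometrically until try_batch fails, then back off.

    try_batch(n) should return True if a batch of n states fits in memory.
    """
    n, last_ok = start, 0
    while n <= max_size:
        if not try_batch(n):
            break
        last_ok = n
        n *= 2  # geometric growth keeps the number of probe passes small
    return last_ok

# Pretend memory fits at most 100 states:
fits = lambda n: n <= 100
print(find_max_batch_size(fits))  # prints 64
```

Doubling keeps the number of expensive probe passes logarithmic in the final batch size, at the cost of some slack below the true limit.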

estimate_max_memory_scaler

Estimate maximum memory scaling metric that fits in GPU memory.

measure_model_memory_forward

Measure peak GPU memory usage during a model's forward pass.

to_constant_volume_bins

Distribute items into bins of fixed maximum volume.
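Constant-volume binning is a classic bin-packing problem. The following is a minimal first-fit-decreasing sketch of the idea, assuming items are given as a name-to-volume mapping; it illustrates the technique only and is not torch_sim's implementation:

```python
def constant_volume_bins(items: dict, max_volume: float) -> list[dict]:
    """Pack items (name -> volume) into bins whose total volume <= max_volume."""
    bins: list[dict] = []
    # First-fit decreasing: place large items first to reduce fragmentation.
    for key, vol in sorted(items.items(), key=lambda kv: kv[1], reverse=True):
        for b in bins:
            if sum(b.values()) + vol <= max_volume:
                b[key] = vol
                break
        else:
            bins.append({key: vol})  # no existing bin fits; open a new one
    return bins

print(constant_volume_bins({"a": 7, "b": 5, "c": 4, "d": 2}, max_volume=9))
```

With the bin volume set to the measured memory budget and item volumes set to each state's memory scaler, every bin becomes a batch that is expected to fit on the GPU.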

Classes

BinningAutoBatcher

Batcher that groups states into bins of similar computational cost.

InFlightAutoBatcher

Batcher that dynamically swaps states based on convergence.
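The hot-swapping strategy can be sketched in miniature. Everything below is a hypothetical toy, not InFlightAutoBatcher's API: each state has a memory cost and a number of steps to convergence, and the loop backfills freed capacity the moment a state converges rather than waiting for the whole batch:

```python
def hot_swap_run(costs: list[int], capacity: int, steps_to_converge: list[int]) -> list[int]:
    """Toy hot-swapping loop: keep the batch full by replacing each state
    as soon as it converges, instead of waiting for the slowest one."""
    pending = list(range(len(costs)))   # state indices not yet started
    active: dict[int, int] = {}         # index -> remaining steps
    finished: list[int] = []

    def fill() -> None:
        # Admit pending states while they fit within the memory budget.
        while pending and sum(costs[i] for i in active) + costs[pending[0]] <= capacity:
            i = pending.pop(0)
            active[i] = steps_to_converge[i]

    fill()
    while active:
        for i in list(active):
            active[i] -= 1              # advance one simulation step
            if active[i] == 0:          # converged: swap it out
                finished.append(i)
                del active[i]
        fill()                          # backfill the freed capacity
    return finished
```

Compared with binning, this keeps the GPU saturated when convergence times vary widely across states, at the cost of a more complex control loop.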