torch_sim.autobatching
Automatic batching and GPU memory management.
This module provides utilities for efficient batch processing of simulation states by dynamically determining optimal batch sizes based on GPU memory constraints. It includes tools for memory usage estimation, batch size determination, and two complementary strategies for batching: binning and hot-swapping.
Example
Using BinningAutoBatcher with a model:
    batcher = BinningAutoBatcher(model, memory_scales_with="n_atoms")
    batcher.load_states(states)
    final_states = []
    for batch in batcher:
        final_states.append(evolve_batch(batch))
    final_states = batcher.restore_original_order(final_states)
Notes
Memory scaling estimates are approximate and may need tuning for specific model architectures and GPU configurations.
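To illustrate why such estimates need tuning, a memory-scaling metric of this kind might be sketched as follows. This is a hypothetical sketch, not the module's actual implementation; the name `memory_scaler` and its `exponent` parameter are illustrative assumptions:

```python
def memory_scaler(n_atoms: int, exponent: float = 1.0) -> float:
    """Relative memory cost of a state with n_atoms atoms.

    Hypothetical sketch: many models scale roughly linearly in atom
    count (exponent=1.0), but attention-like or pairwise interactions
    can push the effective exponent higher, so it must be tuned
    empirically per model architecture and GPU.
    """
    return float(n_atoms) ** exponent
```

States are then packed into batches so that the summed metric stays under a measured GPU memory budget.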
Functions

- Calculate a metric that estimates memory requirements for a state.
- Determine maximum batch size that fits in GPU memory.
- Estimate maximum memory scaling metric that fits in GPU memory.
- Measure peak GPU memory usage during a model's forward pass.
- Distribute items into bins of fixed maximum volume.
Classes

- Batcher that groups states into bins of similar computational cost.
- Batcher that dynamically swaps states based on convergence.
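The hot-swapping strategy can be illustrated with a minimal sketch: when a state converges it is retired and a pending state takes its slot, so the GPU stays full. All names here (`hot_swap_run`, `step`, `converged`, `max_slots`) are hypothetical stand-ins, not this module's API:

```python
def hot_swap_run(states, step, converged, max_slots):
    """Evolve states with at most max_slots active at once.

    step(state) advances a state one iteration; converged(state) tests
    whether it is finished. Converged states are swapped out for pending
    ones immediately, rather than waiting for the whole batch to finish.
    """
    pending = list(states)
    active = [pending.pop(0) for _ in range(min(max_slots, len(pending)))]
    done = []
    while active:
        active = [step(s) for s in active]  # one batched evolution step
        still_running = []
        for s in active:
            if converged(s):
                done.append(s)              # retire the finished state
                if pending:                 # ...and refill its slot
                    still_running.append(pending.pop(0))
            else:
                still_running.append(s)
        active = still_running
    return done
```

Compared with binning, this keeps utilization high when states converge at very different rates, at the cost of reordering results (hence restoring the original order afterwards).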