determine_max_batch_size
- torch_sim.autobatching.determine_max_batch_size(state, model, max_atoms=500_000, start_size=1, scale_factor=1.6)
Determine maximum batch size that fits in GPU memory.
Uses a geometric sequence to search efficiently for the largest number of batches that can be processed without running out of GPU memory. Starting from start_size, the function replicates state and tests progressively larger batch sizes, each roughly scale_factor times the previous one, until it encounters a CUDA out-of-memory error or reaches the atom count given by max_atoms.
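As a rough illustration of the geometric search, the candidate batch sizes can be generated up front. This is a minimal sketch, not the torch_sim implementation: the helper name `geometric_sizes`, the `atoms_per_system` argument, and the rounding rule are assumptions made for the example.

```python
def geometric_sizes(start_size=1, scale_factor=1.6, atoms_per_system=8,
                    max_atoms=500_000):
    """Return geometrically growing batch sizes that stay under an atom cap.

    Hypothetical helper: the names and the rounding rule are assumptions.
    """
    sizes = [start_size]
    while True:
        # Scale the previous size, but always advance by at least one
        # so the loop makes progress even for small scale factors.
        next_size = max(round(sizes[-1] * scale_factor), sizes[-1] + 1)
        # Stop once replicating the system would exceed the atom cap.
        if next_size * atoms_per_system > max_atoms:
            break
        sizes.append(next_size)
    return sizes


# Candidate sizes for an 8-atom system with a 500-atom cap
print(geometric_sizes(start_size=1, scale_factor=1.6,
                      atoms_per_system=8, max_atoms=500))
```

Because the sizes grow geometrically, the number of trial evaluations scales with the logarithm of the final batch size rather than with the batch size itself.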
- Parameters:
state (SimState) – State to replicate for testing.
model (ModelInterface) – Model to test with.
max_atoms (int) – Upper limit on number of atoms to try (for safety). Defaults to 500,000.
start_size (int) – Initial batch size to test. Defaults to 1.
scale_factor (float) – Factor to multiply batch size by in each iteration. Defaults to 1.6.
- Returns:
Maximum number of batches that fit in GPU memory.
- Return type:
int
- Raises:
RuntimeError – If any error other than CUDA out of memory occurs during testing.
Example:
    # Find the maximum batch size for a Lennard-Jones model
    max_batches = determine_max_batch_size(
        state=sample_state, model=lj_model, max_atoms=100_000
    )
Notes
The function returns a batch size slightly smaller than the actual maximum (with a safety margin) to avoid operating too close to memory limits.
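The safety margin can be pictured as stepping back in the list of tested sizes once an out-of-memory error appears. The sketch below is hypothetical: `largest_safe_size`, `try_size`, `fake_forward`, and the one-step margin are illustrative assumptions, and the out-of-memory check simply looks for the text "CUDA out of memory" in the error message. Any other RuntimeError is re-raised, in line with the Raises entry above.

```python
def largest_safe_size(sizes, try_size, safety_steps=1):
    """Return a tested size a few steps below the first one that fails.

    Hypothetical helper: `try_size` stands in for a model forward pass
    and is expected to raise RuntimeError on failure.
    """
    for i, size in enumerate(sizes):
        try:
            try_size(size)
        except RuntimeError as exc:
            if "CUDA out of memory" not in str(exc):
                raise  # unrelated failure: surface it to the caller
            # sizes[i - 1] was the last success; step back further
            # to leave headroom. If the very first size fails, this
            # falls back to sizes[0], which did not fit.
            return sizes[max(0, i - 1 - safety_steps)]
    # No size ran out of memory: the atom cap was the limiting factor.
    return sizes[-1]


def fake_forward(size, limit=20):
    """Simulated forward pass that 'runs out of memory' above `limit`."""
    if size > limit:
        raise RuntimeError("CUDA out of memory (simulated)")


sizes = [1, 2, 3, 5, 8, 13, 21, 34]
print(largest_safe_size(sizes, fake_forward))  # 8
```

Here 13 is the largest size that runs, and the sketch returns 8, trading some throughput for headroom against memory fragmentation and per-step variation.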