BinningAutoBatcher
- class torch_sim.autobatching.BinningAutoBatcher(model, *, memory_scales_with='n_atoms_x_density', max_memory_scaler=None, return_indices=False, max_atoms_to_try=500_000, memory_scaling_factor=1.6, max_memory_padding=1.0)
Bases: object
Batcher that groups states into bins of similar computational cost.
Divides a collection of states into batches that can be processed efficiently without exceeding GPU memory. States are grouped based on a memory scaling metric to maximize GPU utilization. This approach is ideal for scenarios where all states need to be evolved the same number of steps.
To skip the slow memory estimation step, set max_memory_scaler to a known value.
- Variables:
model (ModelInterface) – Model used for memory estimation and processing.
memory_scales_with (str) – Metric type used for memory estimation.
max_memory_scaler (float) – Maximum memory metric allowed per batch.
max_atoms_to_try (int) – Maximum number of atoms to try when estimating memory.
return_indices (bool) – Whether to return original indices with batches.
state_slices (list[SimState]) – Individual states to be batched.
memory_scalers (list[float]) – Memory scaling metrics for each state.
index_to_scaler (dict) – Mapping from state index to its scaling metric.
index_bins (list[list[int]]) – Groups of state indices that can be batched together.
batched_states (list[list[SimState]]) – Grouped states ready for batching.
current_state_bin (int) – Index of the current batch being processed.
- Parameters:
Example:
# Create a batcher with a Lennard-Jones model
batcher = BinningAutoBatcher(
    model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=1000.0
)

# Load states and process them in batches
batcher.load_states(states)
final_states = []
for batch in batcher:
    final_states.append(evolve_batch(batch))

# Restore original order
ordered_final_states = batcher.restore_original_order(final_states)
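The memory_scales_with option selects the metric used to estimate each state's cost. The sketch below illustrates the idea under an assumption: that "n_atoms_x_density" means the atom count multiplied by the number density (atoms per unit volume). The helper memory_scaler and its volume argument are hypothetical, not part of the library, and the library's actual formula and units may differ.

```python
def memory_scaler(n_atoms: int, volume: float, metric: str = "n_atoms_x_density") -> float:
    """Hypothetical helper: estimate a per-state memory metric."""
    if metric == "n_atoms":
        # Cost grows with the number of atoms only.
        return float(n_atoms)
    if metric == "n_atoms_x_density":
        if volume <= 0:
            raise ValueError("volume must be positive")
        # Assumed form: denser systems have more neighbors per atom,
        # so cost is taken as n_atoms * (n_atoms / volume).
        number_density = n_atoms / volume
        return n_atoms * number_density
    raise ValueError(f"unknown metric: {metric!r}")


# Two systems with the same atom count but different volumes.
sparse = memory_scaler(64, volume=512.0)
dense = memory_scaler(64, volume=128.0)
atoms_only = memory_scaler(64, volume=128.0, metric="n_atoms")
```

Under this assumption the denser system receives the larger metric, so fewer dense states fit into one batch for a given max_memory_scaler.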
- load_states(states)
Load new states into the batcher.
Processes the input states, computes memory scaling metrics for each, and organizes them into optimal batches using a bin-packing algorithm to maximize GPU utilization.
- Parameters:
states (list[SimState] | SimState) – Collection of states to batch. Either a list of individual SimState objects or a single batched SimState that will be split into individual states. Each SimState has shape information specific to its instance.
- Returns:
Maximum memory scaling metric that fits in GPU memory.
- Return type:
- Raises:
ValueError – If any individual state has a memory scaling metric greater than the maximum allowed value.
Example:
# Load individual states
batcher.load_states([state1, state2, state3])

# Or load a batched state that will be split
batcher.load_states(batched_state)
Notes
Calling this method resets the current state bin index, so any iteration in progress restarts from the first batch.
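The binning step that load_states describes can be pictured with a small pure-Python sketch. It uses first-fit decreasing, a common bin-packing heuristic chosen here for illustration; the library may use a different algorithm. The function bin_indices and its arguments are hypothetical names.

```python
def bin_indices(scalers: list[float], max_scaler: float) -> list[list[int]]:
    """Group state indices into bins whose summed metric stays within max_scaler."""
    # Mirror the documented ValueError: no single state may exceed the limit.
    for i, s in enumerate(scalers):
        if s > max_scaler:
            raise ValueError(f"state {i} has metric {s} > maximum {max_scaler}")

    # First-fit decreasing: place the largest states first.
    order = sorted(range(len(scalers)), key=lambda i: scalers[i], reverse=True)
    bins: list[list[int]] = []
    loads: list[float] = []
    for i in order:
        for b, load in enumerate(loads):
            if load + scalers[i] <= max_scaler:
                bins[b].append(i)
                loads[b] += scalers[i]
                break
        else:
            # No existing bin has room, so open a new one.
            bins.append([i])
            loads.append(scalers[i])
    return bins


scalers = [4.0, 8.0, 1.0, 4.0, 2.0, 1.0]
index_bins = bin_indices(scalers, max_scaler=10.0)
```

Each inner list plays the role of one entry in index_bins: a set of original state indices whose combined metric fits within the limit.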
- next_batch(*, return_indices=False)
Get the next batch of states.
Returns batches sequentially until all states have been processed. Each batch contains states grouped together to maximize GPU utilization without exceeding memory constraints.
- Parameters:
return_indices (bool) – Whether to return original indices along with the batch. Overrides the value set during initialization. Defaults to False.
- Returns:
If return_indices is False: A concatenated SimState containing the next batch of states, or None if no more batches.
If return_indices is True: Tuple of (concatenated SimState, indices), where indices are the original positions of the states, or None if no more batches.
- Return type:
Example:
# Get batches one by one until the batcher is exhausted
final_states = []
while (batch := batcher.next_batch()) is not None:
    final_states.append(evolve_batch(batch))
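To show how the return value changes with return_indices, the sketch below uses a toy stand-in that follows the contract described above: a batch (or a batch plus indices) per call, then None once the bins are exhausted. ToyBatcher is hypothetical and works on plain Python lists rather than SimState objects.

```python
class ToyBatcher:
    """Hypothetical stand-in mimicking the documented next_batch contract."""

    def __init__(self, items, index_bins):
        self.items = items
        self.index_bins = index_bins
        self.current_bin = 0

    def next_batch(self, *, return_indices=False):
        if self.current_bin >= len(self.index_bins):
            return None  # no more batches
        indices = self.index_bins[self.current_bin]
        self.current_bin += 1
        batch = [self.items[i] for i in indices]
        if return_indices:
            return batch, indices
        return batch


toy = ToyBatcher(["a", "b", "c", "d"], index_bins=[[1, 3], [0, 2]])
seen = []
while (result := toy.next_batch(return_indices=True)) is not None:
    batch, indices = result  # unpack only after the None check
    seen.append((indices, batch))
```

Checking for None before unpacking matters here, because the exhausted case returns a bare None rather than a tuple.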
- restore_original_order(batched_states)
Reorder processed states back to their original sequence.
Takes states that were processed in batches and restores them to the original order they were provided in. This is essential after batch processing to ensure results correspond to the input states.
- Parameters:
batched_states (list[SimState]) – State batches to reorder. These can be either concatenated batch states that will be split, or already split individual states.
- Returns:
States in their original order, with shape information matching the original input states.
- Return type:
- Raises:
ValueError – If the number of states doesn’t match the number of original indices.
Example:
# Process batches and restore original order
results = []
for batch in batcher:
    results.append(process_batch(batch))
ordered_results = batcher.restore_original_order(results)
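The reordering itself amounts to inverting the index bins. A minimal sketch, assuming the processed results have already been split into one item per state and that the bins together cover every original index exactly once; restore_order is a hypothetical helper, not the library method:

```python
def restore_order(index_bins, processed_bins):
    """Place each processed item back at its original position."""
    flat_indices = [i for b in index_bins for i in b]
    flat_results = [r for b in processed_bins for r in b]
    # Mirror the documented ValueError on a count mismatch.
    if len(flat_indices) != len(flat_results):
        raise ValueError(
            f"got {len(flat_results)} states for {len(flat_indices)} original indices"
        )
    ordered = [None] * len(flat_results)
    for original_index, result in zip(flat_indices, flat_results):
        ordered[original_index] = result
    return ordered


# Bins hold original indices; results arrive in bin order.
ordered = restore_order([[1, 3], [0, 2]], [["B", "D"], ["A", "C"]])
```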