InFlightAutoBatcher¶
- class torch_sim.autobatching.InFlightAutoBatcher(model, *, memory_scales_with='n_atoms_x_density', max_memory_scaler=None, max_atoms_to_try=500_000, memory_scaling_factor=1.6, return_indices=False, max_iterations=None, max_memory_padding=1.0)[source]¶
Bases: object
Batcher that dynamically swaps states based on convergence.
Optimizes GPU utilization by removing converged states from the batch and adding new states to process. This approach is ideal for iterative processes where different states may converge at different rates, such as geometry optimization.
To avoid a slow memory estimation step, set the max_memory_scaler to a known value.
- Variables:
model (ModelInterface) – Model used for memory estimation and processing.
memory_scales_with (str) – Metric type used for memory estimation.
max_memory_scaler (float) – Maximum memory metric allowed per batch.
max_atoms_to_try (int) – Maximum number of atoms to try when estimating memory.
return_indices (bool) – Whether to return original indices with batches.
max_iterations (int | None) – Maximum number of iterations per state.
state_slices (list[SimState]) – Individual states to be batched.
memory_scalers (list[float]) – Memory scaling metrics for each state.
current_idx (list[int]) – Indices of states in the current batch.
completed_idx (list[int]) – Indices of states that have been processed.
completed_idx_og_order (list[int]) – Original indices of completed states.
current_scalers (list[float]) – Memory metrics for states in current batch.
swap_attempts (dict[int, int]) – Count of iterations for each state.
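The two metric choices named in the signature ("n_atoms" and "n_atoms_x_density") can be illustrated with a sketch. This is a hypothetical helper, not torch_sim's actual implementation; the exact formula is an assumption, intended only to show how the per-state memory scaler might be derived from atom count and cell volume.

```python
def memory_scaler(
    n_atoms: int, volume: float, memory_scales_with: str = "n_atoms_x_density"
) -> float:
    """Hypothetical sketch of the per-state memory metric."""
    if memory_scales_with == "n_atoms":
        # Memory assumed proportional to atom count alone.
        return float(n_atoms)
    if memory_scales_with == "n_atoms_x_density":
        # Denser cells imply larger neighbor lists, so weight the
        # atom count by the number density (atoms per unit volume).
        density = n_atoms / volume
        return n_atoms * density
    raise ValueError(f"Unknown metric: {memory_scales_with!r}")
```

States are then packed into a batch until the sum of their scalers would exceed max_memory_scaler.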
Example:
```python
# Create a hot-swapping batcher
batcher = InFlightAutoBatcher(
    model=lj_model,
    memory_scales_with="n_atoms",
    max_memory_scaler=1000.0,
)

# Load states and process them with convergence checking
batcher.load_states(states)
batch, completed_states = batcher.next_batch(None, None)
while batch is not None:
    # Process the batch
    batch = process_batch(batch)
    # Check convergence
    convergence = check_convergence(batch)
    # Get next batch, with converged states swapped out
    batch, new_completed = batcher.next_batch(batch, convergence)
    completed_states.extend(new_completed)

# Restore original order
ordered_results = batcher.restore_original_order(completed_states)
```
- load_states(states)[source]¶
Load new states into the batcher.
Processes the input states, computes memory scaling metrics for each, and prepares them for dynamic batching based on convergence criteria. Unlike BinningAutoBatcher, this doesn’t create fixed batches upfront.
- Parameters:
states (list[SimState] | Iterator[SimState] | SimState) – Collection of states to batch. Can be a list of individual SimState objects, an iterator yielding SimState objects, or a single batched SimState that will be split into individual states. Each SimState has shape information specific to its instance.
- Raises:
ValueError – If any individual state has a memory scaling metric greater than the maximum allowed value.
- Return type:
None
Example:
```python
# Load individual states
batcher.load_states([state1, state2, state3])

# Or load a batched state that will be split
batcher.load_states(batched_state)

# Or load states from an iterator
batcher.load_states(state_generator())
```
Notes
This method resets the current state indices and completed state tracking, so any ongoing processing will be restarted when this method is called.
- next_batch(updated_state, convergence_tensor)[source]¶
Get the next batch of states based on convergence.
Removes converged states from the batch, adds new states if possible, and returns both the updated batch and the completed states. This method implements the core dynamic batching strategy of the InFlightAutoBatcher.
- Parameters:
updated_state (SimState | None) – Current state after processing, or None for the first call. Contains shape information specific to the SimState instance.
convergence_tensor (Tensor | None) – Boolean tensor with shape [n_batches] indicating which states have converged (True) or not (False). Should be None only for the first call.
- Returns:
If return_indices is False: tuple of (next_batch, completed_states), where next_batch is a SimState or None if all states are processed, and completed_states is a list of SimState objects.
If return_indices is True: tuple of (next_batch, completed_states, indices), where indices are the current batch's positions.
- Raises:
AssertionError – If convergence_tensor doesn’t match the expected shape or if other validation checks fail.
- Return type:
tuple[SimState | None, list[SimState]] | tuple[SimState | None, list[SimState], list[int]]
Example:
```python
# Initial call
batch, completed = batcher.next_batch(None, None)

# Process batch and check for convergence
batch = process_batch(batch)
convergence = check_convergence(batch)

# Get next batch with converged states removed and new states added
batch, completed = batcher.next_batch(batch, convergence)
```
Notes
When max_iterations is set, states that exceed this limit will be forcibly marked as converged regardless of their actual convergence state.
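One way to build the convergence_tensor expected by next_batch is from per-atom forces. The helper below is a hypothetical sketch, not part of torch_sim's API; it assumes forces of shape (n_atoms_total, 3) and a batch_index tensor mapping each atom to its state, and marks a state converged when its largest force norm falls below a tolerance.

```python
import torch


def force_convergence_tensor(
    forces: torch.Tensor, batch_index: torch.Tensor, tol: float = 0.02
) -> torch.Tensor:
    """Boolean tensor of shape [n_batches]: max force norm per state < tol."""
    n_batches = int(batch_index.max().item()) + 1
    fmag = forces.norm(dim=1)  # per-atom force magnitudes
    max_force = torch.zeros(n_batches, dtype=fmag.dtype)
    # "amax" reduction: largest force norm among each state's atoms
    max_force.scatter_reduce_(0, batch_index, fmag, reduce="amax")
    return max_force < tol
```

The resulting tensor can be passed directly as the convergence_tensor argument.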
- restore_original_order(completed_states)[source]¶
Reorder completed states back to their original sequence.
Takes states that were completed in arbitrary order and restores them to the original order they were provided in. This is essential after using the hot-swapping strategy to ensure results correspond to input states.
- Parameters:
completed_states (list[SimState]) – Completed states to reorder. Each SimState contains simulation data with shape specific to its instance.
- Returns:
States in their original order, with shape information matching the original input states.
- Return type:
list[SimState]
- Raises:
ValueError – If the number of completed states doesn’t match the number of completed indices.
Example:
```python
# After processing with next_batch
all_completed_states = []

# Process all states
while batch is not None:
    batch = process_batch(batch)
    convergence = check_convergence(batch)
    batch, new_completed = batcher.next_batch(batch, convergence)
    all_completed_states.extend(new_completed)

# Restore original order
ordered_results = batcher.restore_original_order(all_completed_states)
```
Notes
This method should only be called after all states have been processed, or you will only get the subset of states that have completed so far.
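The reordering step itself reduces to a sort by the recorded original indices (the completed_idx_og_order attribute described above). A minimal standalone sketch, assuming each completed state carries its original input index in that list:

```python
def reorder_by_original_index(completed_states, completed_idx_og_order):
    """Hypothetical sketch: sort completed states by their original input index.

    completed_idx_og_order[i] is the index that completed_states[i] occupied
    in the original input; sorting by it recovers the input order.
    """
    if len(completed_states) != len(completed_idx_og_order):
        raise ValueError("Number of completed states does not match completed indices")
    # Sort positions by original index rather than zipping states directly,
    # so the states themselves never need to be comparable.
    order = sorted(range(len(completed_states)), key=lambda i: completed_idx_og_order[i])
    return [completed_states[i] for i in order]
```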