InFlightAutoBatcher

class torch_sim.autobatching.InFlightAutoBatcher(model, *, memory_scales_with='n_atoms_x_density', max_memory_scaler=None, max_atoms_to_try=500_000, memory_scaling_factor=1.6, return_indices=False, max_iterations=None, max_memory_padding=1.0)

Bases: object

Batcher that dynamically swaps states based on convergence.

Optimizes GPU utilization by removing converged states from the batch and adding new states to process. This approach is ideal for iterative processes where different states may converge at different rates, such as geometry optimization.
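To make the utilization argument concrete, here is a toy model (not torch_sim code). It assumes each state needs a known, hypothetical number of steps to converge, and that a batch holds a fixed number of states rather than a memory budget. It counts model calls for fixed batches, where every batch runs until its slowest member converges, and for in-flight batching, where a converged state's slot is refilled immediately.

```python
from collections import deque


def static_batching_steps(steps_needed: list[int], batch_size: int) -> int:
    """Model calls when each fixed batch runs until its slowest state converges."""
    total = 0
    for start in range(0, len(steps_needed), batch_size):
        total += max(steps_needed[start:start + batch_size])
    return total


def in_flight_steps(steps_needed: list[int], batch_size: int) -> int:
    """Model calls when converged states are swapped out for queued ones."""
    queue = deque(range(len(steps_needed)))
    remaining = {}  # state index -> steps still required
    total = 0
    while queue or remaining:
        # Fill free slots with waiting states.
        while queue and len(remaining) < batch_size:
            idx = queue.popleft()
            remaining[idx] = steps_needed[idx]
        # One model call advances every state in the batch by one step.
        total += 1
        for idx in list(remaining):
            remaining[idx] -= 1
            if remaining[idx] <= 0:
                del remaining[idx]  # converged: its slot is freed
    return total


steps = [10, 1, 1, 1, 10, 1, 1, 1]
print(static_batching_steps(steps, batch_size=2))  # 22
print(in_flight_steps(steps, batch_size=2))        # 13
```

In the fixed scheme, fast states sit idle in their batch while a slow neighbor finishes. The in-flight scheme keeps every slot busy, so the same work takes fewer model calls.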

To avoid a slow memory estimation step, set the max_memory_scaler to a known value.
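The estimation step is presumably slow because it needs trial model evaluations on progressively larger inputs. The sketch below shows one such search. It assumes trial sizes grow geometrically by a factor like memory_scaling_factor up to max_atoms_to_try, and that the result is scaled by a safety margin like max_memory_padding. The fits_in_memory callback is hypothetical, sizes are measured in atoms for simplicity, and this is an illustration of the idea rather than torch_sim's estimation routine.

```python
from typing import Callable


def estimate_max_atoms(
    fits_in_memory: Callable[[int], bool],  # hypothetical: would run a trial batch
    start: int = 100,
    scaling_factor: float = 1.6,
    max_atoms_to_try: int = 500_000,
    padding: float = 1.0,
) -> float:
    """Grow a trial size geometrically; return the largest size that fit, times padding."""
    last_ok = None
    n = start
    while n <= max_atoms_to_try:
        if not fits_in_memory(n):  # in practice: a model call that may run out of memory
            break
        last_ok = n
        n = max(n + 1, round(n * scaling_factor))  # always make progress
    if last_ok is None:
        raise RuntimeError("even the smallest trial size did not fit")
    return last_ok * padding


# Pretend the device can hold at most 1000 atoms.
print(estimate_max_atoms(lambda n: n <= 1000))  # 656.0
```

Each iteration costs one trial evaluation, so passing a known max_memory_scaler skips this loop entirely.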

[Figure: In-flight auto-batcher diagram]

Variables:
  • model (ModelInterface) – Model used for memory estimation and processing.

  • memory_scales_with (str) – Metric used for memory estimation: 'n_atoms_x_density' or 'n_atoms'.

  • max_memory_scaler (float) – Maximum memory metric allowed per batch.

  • max_atoms_to_try (int) – Maximum number of atoms to try when estimating memory.

  • return_indices (bool) – Whether to return original indices with batches.

  • max_iterations (int | None) – Maximum number of iterations per state.

  • state_slices (list[SimState]) – Individual states to be batched.

  • memory_scalers (list[float]) – Memory scaling metrics for each state.

  • current_idx (list[int]) – Indices of states in the current batch.

  • completed_idx (list[int]) – Indices of states that have been processed.

  • completed_idx_og_order (list[int]) – Original indices of completed states.

  • current_scalers (list[float]) – Memory metrics for states in current batch.

  • swap_attempts (dict[int, int]) – Count of iterations for each state.

Parameters:
  • model (ModelInterface)

  • memory_scales_with (Literal['n_atoms_x_density', 'n_atoms'])

  • max_memory_scaler (float | None)

  • max_atoms_to_try (int)

  • memory_scaling_factor (float)

  • return_indices (bool)

  • max_iterations (int | None)

  • max_memory_padding (float)
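The two documented options for memory_scales_with can be illustrated as follows. The sketch assumes that "density" means number density (atoms per unit cell volume), so that 'n_atoms_x_density' equals n_atoms squared divided by volume. ToySystem and memory_scaler are hypothetical names, and torch_sim's own units and normalization may differ.

```python
from dataclasses import dataclass


@dataclass
class ToySystem:
    """Hypothetical stand-in for one SimState: atom count and cell volume."""

    n_atoms: int
    volume: float  # cell volume in arbitrary units


def memory_scaler(system: ToySystem, memory_scales_with: str = "n_atoms_x_density") -> float:
    """Compute a per-system memory metric for the two documented options."""
    if memory_scales_with == "n_atoms":
        return float(system.n_atoms)
    if memory_scales_with == "n_atoms_x_density":
        number_density = system.n_atoms / system.volume  # assumed definition
        return system.n_atoms * number_density
    raise ValueError(f"unknown memory_scales_with: {memory_scales_with!r}")


sparse = ToySystem(n_atoms=100, volume=2000.0)
dense = ToySystem(n_atoms=100, volume=1000.0)
print(memory_scaler(sparse, "n_atoms"), memory_scaler(dense, "n_atoms"))  # 100.0 100.0
print(memory_scaler(sparse), memory_scaler(dense))  # dense scores twice as high
```

Under this assumption, two systems with the same atom count look identical to 'n_atoms' but not to 'n_atoms_x_density'. That is plausibly the motivation for the density-aware metric, since denser systems have more neighbors per atom.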

Example:

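from torch_sim.autobatching import InFlightAutoBatcher

# Create an in-flight (hot-swapping) batcher. lj_model, states, process_batch,
# and check_convergence are assumed to be defined by the user.
batcher = InFlightAutoBatcher(
    model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=1000.0
)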

# Load states and process them with convergence checking
batcher.load_states(states)
batch, completed_states = batcher.next_batch(None, None)

while batch is not None:
    # Process the batch
    batch = process_batch(batch)

    # Check convergence
    convergence = check_convergence(batch)

    # Get next batch, with converged states swapped out
    batch, new_completed = batcher.next_batch(batch, convergence)
    completed_states.extend(new_completed)

# Restore original order
ordered_results = batcher.restore_original_order(completed_states)

load_states(states)

Load new states into the batcher.

Processes the input states, computes memory scaling metrics for each, and prepares them for dynamic batching based on convergence criteria. Unlike BinningAutoBatcher, this doesn’t create fixed batches upfront.

Parameters:

states (list[SimState] | Iterator[SimState] | SimState) – Collection of states to batch. Can be a list of individual SimState objects, an iterator yielding SimState objects, or a single batched SimState that will be split into individual states. Individual states may differ in size, for example in their number of atoms.

Raises:

ValueError – If any individual state has a memory scaling metric greater than max_memory_scaler.

Return type:

None

Example:

# Load individual states
batcher.load_states([state1, state2, state3])

# Or load a batched state that will be split
batcher.load_states(batched_state)

# Or load states from an iterator
batcher.load_states(state_generator())

Notes

This method resets the current state indices and completed state tracking, so any ongoing processing will be restarted when this method is called.
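The validation and reset behavior described for load_states can be mimicked in a few lines. This is a minimal sketch, not torch_sim code: ToyState and ToyLoader are hypothetical, the memory metric is assumed to be 'n_atoms', and splitting a single batched SimState is omitted.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass
class ToyState:
    """Hypothetical stand-in for a single (unbatched) SimState."""

    n_atoms: int


class ToyLoader:
    """Mimics the validation and reset behavior described for load_states."""

    def __init__(self, max_memory_scaler: float) -> None:
        self.max_memory_scaler = max_memory_scaler
        self.state_slices: list[ToyState] = []
        self.memory_scalers: list[float] = []
        self.current_idx: list[int] = []
        self.completed_idx_og_order: list[int] = []

    def load_states(self, states: Iterable[ToyState]) -> None:
        states = list(states)  # accept a list or any iterator
        scalers = [float(s.n_atoms) for s in states]  # 'n_atoms' metric
        # Validate everything before mutating, so a failed load leaves no partial state.
        for i, scaler in enumerate(scalers):
            if scaler > self.max_memory_scaler:
                raise ValueError(
                    f"state {i} has memory scaler {scaler}, "
                    f"above max_memory_scaler={self.max_memory_scaler}"
                )
        self.state_slices = states
        self.memory_scalers = scalers
        # Reset progress tracking: in-flight work from a previous load is dropped.
        self.current_idx = []
        self.completed_idx_og_order = []


loader = ToyLoader(max_memory_scaler=100.0)
loader.load_states(iter([ToyState(10), ToyState(50)]))
print(loader.memory_scalers)  # [10.0, 50.0]
```

Checking every state before assigning anything means a rejected call leaves the previously loaded states untouched.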

next_batch(updated_state, convergence_tensor)

Get the next batch of states based on convergence.

Removes converged states from the batch, adds new states if possible, and returns both the updated batch and the completed states. This method implements the core dynamic batching strategy of the InFlightAutoBatcher.

Parameters:
  • updated_state (SimState | None) – Current state after processing, or None for the first call. Its shape depends on which states are currently in the batch.

  • convergence_tensor (Tensor | None) – Boolean tensor of shape [n_batches], with one entry per state in the current batch, indicating whether that state has converged (True) or not (False). Should be None only for the first call.
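The placeholder check_convergence used in the examples must produce one flag per state in the batch. A minimal sketch follows, assuming a force-based criterion common in geometry optimization (largest per-atom force norm below a hypothetical threshold fmax). Plain Python lists stand in for the batched state's forces and per-atom system indices; in real use the flags would be converted to a boolean tensor, for example with torch.tensor(flags, dtype=torch.bool), before being passed as convergence_tensor.

```python
import math


def check_convergence(forces, system_idx, n_systems, fmax=0.05):
    """Return one bool per system: True when its largest per-atom force norm is below fmax.

    forces: sequence of (fx, fy, fz) tuples, one per atom
    system_idx: for each atom, the index of the system it belongs to
    """
    max_force = [0.0] * n_systems
    for (fx, fy, fz), system in zip(forces, system_idx):
        norm = math.sqrt(fx * fx + fy * fy + fz * fz)
        max_force[system] = max(max_force[system], norm)
    return [f < fmax for f in max_force]


forces = [(0.01, 0.0, 0.0), (0.0, 0.02, 0.0), (0.3, 0.0, 0.0), (0.0, 0.0, 0.001)]
print(check_convergence(forces, system_idx=[0, 0, 1, 2], n_systems=3))  # [True, False, True]
```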

Returns:

  • If return_indices is False: a tuple (next_batch, completed_states), where next_batch is a SimState, or None if all states have been processed, and completed_states is a list of SimState objects.

  • If return_indices is True: a tuple (next_batch, completed_states, indices), where indices are the indices of the states in the current batch (as in current_idx).

Raises:

AssertionError – If convergence_tensor doesn’t match the expected shape or if other validation checks fail.

Return type:

tuple[SimState | None, list[SimState]] | tuple[SimState | None, list[SimState], list[int]]

Example:

# Initial call
batch, completed = batcher.next_batch(None, None)

# Process batch and check for convergence
batch = process_batch(batch)
convergence = check_convergence(batch)

# Get next batch with converged states removed and new states added
batch, completed = batcher.next_batch(batch, convergence)

Notes

When max_iterations is set, states that exceed this limit will be forcibly marked as converged regardless of their actual convergence state.
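The core swap can be sketched as a single pure-Python step over state indices, under stated assumptions. Converged states, plus any that have used up a per-state step budget, leave the batch, and queued states are admitted in order while the summed memory metric stays within max_memory_scaler. The names swap_step, queue, and iterations are illustrative. The first-in-first-out admission policy and the >= comparison against max_iterations are assumptions, not documented torch_sim behavior.

```python
from typing import Optional


def swap_step(
    current_idx: list[int],
    convergence: list[bool],
    queue: list[int],
    scalers: list[float],
    max_memory_scaler: float,
    iterations: dict[int, int],
    max_iterations: Optional[int] = None,
) -> tuple[list[int], list[int]]:
    """One in-flight update: drop converged states, then refill from the queue.

    Mutates `queue` (consumed from the front) and `iterations` (steps per state).
    Returns (new_current_idx, newly_completed_idx).
    """
    running: list[int] = []
    completed: list[int] = []
    for idx, converged in zip(current_idx, convergence):
        iterations[idx] = iterations.get(idx, 0) + 1
        # A state that has used up its step budget is treated as converged.
        if max_iterations is not None and iterations[idx] >= max_iterations:
            converged = True
        if converged:
            completed.append(idx)
        else:
            running.append(idx)

    used = sum(scalers[i] for i in running)
    # Add queued states in order while the next one still fits in the budget.
    while queue and used + scalers[queue[0]] <= max_memory_scaler:
        idx = queue.pop(0)
        running.append(idx)
        used += scalers[idx]
    return running, completed


scalers = [40.0, 40.0, 40.0, 30.0, 10.0]
queue = [2, 3, 4]
iterations = {}
batch, done = swap_step([0, 1], [True, False], queue, scalers, 100.0, iterations)
print(batch, done, queue)  # [1, 2] [0] [3, 4]
```

Stopping at the first queued state that does not fit preserves input order. A best-fit policy could pack the budget more tightly, but it risks repeatedly skipping large states.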

restore_original_order(completed_states)

Reorder completed states back to their original sequence.

Takes states that were completed in arbitrary order and restores them to the original order they were provided in. This is essential after using the hot-swapping strategy to ensure results correspond to input states.

Parameters:

completed_states (list[SimState]) – Completed states to reorder. Individual states may differ in size, for example in their number of atoms.

Returns:

States in their original order, with shape information matching the original input states.

Return type:

list[SimState]

Raises:

ValueError – If the number of completed states doesn’t match the number of completed indices.

Example:

# After processing with next_batch
all_completed_states = []

# Process all states
while batch is not None:
    batch = process_batch(batch)
    convergence = check_convergence(batch)
    batch, new_completed = batcher.next_batch(batch, convergence)
    all_completed_states.extend(new_completed)

# Restore original order
ordered_results = batcher.restore_original_order(all_completed_states)

Notes

This method should only be called after all states have been processed, or you will only get the subset of states that have completed so far.
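Reordering only needs each completed state's original index, which the batcher tracks (see completed_idx_og_order in the Variables list). A minimal sketch with a hypothetical standalone function that takes that list explicitly:

```python
def restore_original_order(completed_states: list, completed_idx_og_order: list[int]) -> list:
    """Return states sorted by the original input index recorded for each one."""
    if len(completed_states) != len(completed_idx_og_order):
        raise ValueError(
            f"{len(completed_states)} completed states but "
            f"{len(completed_idx_og_order)} completed indices"
        )
    # Sort by index only, so the states themselves never need to be comparable.
    pairs = sorted(zip(completed_idx_og_order, completed_states), key=lambda pair: pair[0])
    return [state for _, state in pairs]


print(restore_original_order(["c", "a", "b"], [2, 0, 1]))  # ['a', 'b', 'c']
```

Sorting on the index key alone matters because real SimState objects do not define an ordering.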