estimate_max_memory_scaler

torch_sim.autobatching.estimate_max_memory_scaler(model, state_list, metric_values, **kwargs)[source]

Estimate the maximum memory-scaling metric value that fits in GPU memory.

Tests batch sizes with both the minimum- and maximum-metric states to determine a safe upper bound for the memory scaling metric. This ensures the estimate holds for both small, dense systems and large, sparse systems.

Parameters:
  • model (ModelInterface) – Model to test with, implementing the ModelInterface protocol.

  • state_list (list[SimState]) – States to test, each with shape information specific to the SimState instance.

  • metric_values (list[float]) – Corresponding metric values for each state, as calculated by calculate_memory_scaler().

  • **kwargs – Additional keyword arguments passed to determine_max_batch_size.

Returns:

Maximum safe metric value that fits in GPU memory.

Return type:

float

Example:

from torch_sim.autobatching import (
    calculate_memory_scaler,
    estimate_max_memory_scaler,
)

# Calculate memory-scaling metrics for a set of states
metrics = [calculate_memory_scaler(state) for state in states]

# Estimate the maximum metric value that safely fits in GPU memory
max_metric = estimate_max_memory_scaler(model, states, metrics)
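The estimated maximum is typically used to pack states into batches whose combined metric stays under the limit. A minimal sketch with a hypothetical greedy binning helper (`greedy_batches` is illustrative, not part of torch_sim's API):

```python
# Hypothetical greedy binning: group state indices into batches whose
# summed memory metric stays at or below the estimated maximum.
def greedy_batches(metrics, max_metric):
    batches, current, total = [], [], 0.0
    for idx, metric in enumerate(metrics):
        # Start a new batch when adding this state would exceed the limit.
        if current and total + metric > max_metric:
            batches.append(current)
            current, total = [], 0.0
        current.append(idx)
        total += metric
    if current:
        batches.append(current)
    return batches

print(greedy_batches([400.0, 300.0, 500.0, 100.0], 800.0))
# → [[0, 1], [2, 3]]
```

A real implementation would also need to handle a single state whose metric already exceeds `max_metric`.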

Notes

This function tests batch sizes with both the smallest and largest systems to find a conservative estimate that works across varying system sizes. The returned value will be the minimum of the two estimates.
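The conservative min-of-two-estimates logic can be sketched in isolation. Here `fits` is a toy stand-in for the GPU-memory probe that the real function performs via determine_max_batch_size; the function names and the linear-capacity model are assumptions for illustration only:

```python
# Illustrative sketch: estimate the largest safe metric value by probing
# batch sizes for both the smallest and largest per-state metric, then
# taking the minimum of the two estimates.
def max_safe_metric(metric_values, fits, max_batch=4096):
    def estimate(metric):
        # Grow the batch while the next size still fits in memory.
        n = 1
        while n < max_batch and fits((n + 1) * metric):
            n += 1
        return n * metric

    lo, hi = min(metric_values), max(metric_values)
    return min(estimate(lo), estimate(hi))

# Toy memory predicate: pretend a batch fits when its total metric <= 1000.
print(max_safe_metric([30.0, 400.0], lambda total: total <= 1000))
# → 800.0 (the smallest-state probe yields 990.0, the largest 800.0)
```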