estimate_max_memory_scaler
- torch_sim.autobatching.estimate_max_memory_scaler(model, state_list, metric_values, **kwargs)
Estimate maximum memory scaling metric that fits in GPU memory.
Tests the states with the smallest and the largest metric values to determine a safe upper bound for the memory scaling metric. Checking both extremes ensures the estimated value works for small, dense systems as well as large, sparse systems.
- Parameters:
model (ModelInterface) – Model to test with, implementing the ModelInterface protocol.
state_list (list[SimState]) – States to test, each with shape information specific to the SimState instance.
metric_values (list[float]) – Corresponding metric values for each state, as calculated by calculate_memory_scaler().
**kwargs – Additional keyword arguments passed to determine_max_batch_size.
- Returns:
Maximum safe metric value that fits in GPU memory.
- Return type:
float
Example:
    # Calculate metrics for a set of states
    metrics = [calculate_memory_scaler(state) for state in states]

    # Estimate maximum safe metric value
    max_metric = estimate_max_memory_scaler(model, states, metrics)
Notes
This function tests batch sizes with both the smallest and largest systems to find a conservative estimate that works across varying system sizes. The returned value will be the minimum of the two estimates.
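The "minimum of the two estimates" rule described above can be illustrated with a small, self-contained sketch. This is not the torch_sim implementation: `largest_fitting_batch`, `sketch_max_memory_scaler`, and the toy `fits` memory model are hypothetical names introduced only for illustration, and the sketch assumes each estimate is the largest batch size that fits multiplied by that state's metric value.

```python
from collections.abc import Callable, Sequence


def largest_fitting_batch(fits: Callable[[int], bool], max_batches: int = 1024) -> int:
    """Largest n in [0, max_batches] for which fits(n) is True.

    Hypothetical stand-in for a batch-size search; assumes fits is
    monotone (if n copies fit in memory, any smaller count also fits).
    """
    lo, hi = 0, max_batches
    while lo < hi:
        mid = (lo + hi + 1) // 2  # round up so the loop always makes progress
        if fits(mid):
            lo = mid
        else:
            hi = mid - 1
    return lo


def sketch_max_memory_scaler(
    metric_values: Sequence[float],
    fits: Callable[[int, int], bool],
    max_batches: int = 1024,
) -> float:
    """Conservative metric bound from the smallest- and largest-metric states.

    fits(i, n) is an assumed predicate: do n copies of state i fit in memory?
    """
    indices = range(len(metric_values))
    i_min = min(indices, key=lambda i: metric_values[i])
    i_max = max(indices, key=lambda i: metric_values[i])

    estimates = []
    for i in (i_min, i_max):
        n = largest_fitting_batch(lambda k, i=i: fits(i, k), max_batches)
        estimates.append(n * metric_values[i])

    # Keep the smaller estimate so the bound holds at both extremes.
    return min(estimates)


# Toy memory model (an assumption): each system costs its metric value plus
# a fixed per-system overhead, and the total must stay within a budget.
metrics = [10.0, 40.0, 250.0]
OVERHEAD, BUDGET = 5.0, 1000.0


def toy_fits(i: int, n: int) -> bool:
    return n * (metrics[i] + OVERHEAD) <= BUDGET


estimate = sketch_max_memory_scaler(metrics, toy_fits)
print(estimate)
```

In this toy model the smallest system fits 66 times (estimate 660.0) and the largest fits 3 times (estimate 750.0), so the smaller value is returned. Fixed per-system overhead makes many small systems costlier than their metric alone suggests, which is why testing only the largest state could overestimate the safe bound.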