measure_model_memory_forward

torch_sim.autobatching.measure_model_memory_forward(state, model)

Measure peak GPU memory usage during a model’s forward pass.

Clears the GPU cache, runs a forward pass on the provided state, and records the maximum memory allocated during that pass. Use it to determine how much GPU memory a given simulation state actually requires.
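The measurement procedure described above can be sketched with standard PyTorch CUDA memory APIs. This is a minimal illustration, not torch_sim's implementation: the helper name `measure_forward_peak_gb` and its `forward`/`inputs`/`device` parameters are hypothetical, and the exact calls torch_sim makes may differ. Grad mode is left unchanged, since a model that derives forces through autograd would need gradients enabled during the pass.

```python
from typing import Any, Callable

import torch


def measure_forward_peak_gb(
    forward: Callable[[Any], Any], inputs: Any, device: torch.device
) -> float:
    """Hypothetical sketch: peak GPU memory (in 2**30-byte units) of one forward call."""
    if device.type == "cpu":
        # Mirrors the documented behavior: measuring on CPU is rejected.
        raise ValueError("Memory estimation is only meaningful on a GPU device.")

    # Wait for queued kernels, then release cached blocks so earlier
    # work does not distort the reading.
    torch.cuda.synchronize(device)
    torch.cuda.empty_cache()
    # Reset the high-water mark so only the forward pass below is counted.
    torch.cuda.reset_peak_memory_stats(device)

    forward(inputs)
    # Make sure all kernels launched by the forward pass have finished.
    torch.cuda.synchronize(device)

    # max_memory_allocated reports bytes; convert to gigabytes (2**30 bytes).
    return torch.cuda.max_memory_allocated(device) / 1024**3
```

On a CUDA machine, `measure_forward_peak_gb(model, batch, torch.device("cuda"))` would return a single float; on CPU it raises `ValueError`, matching the behavior documented under Raises.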

Parameters:
  • state (SimState) – Input state passed to the model. Its tensor shapes, and therefore the measured memory, depend on this particular SimState instance.

  • model (ModelInterface) – Model to measure memory usage for, implementing the ModelInterface protocol.

Returns:

Peak memory usage in gigabytes.

Return type:

float

Raises:

ValueError – If the model device is CPU, as memory estimation is only meaningful for GPU-based models.

Notes

Before measuring, this function synchronizes the GPU and clears its memory cache. Both steps add overhead, so calling it frequently can slow down a simulation.
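Because each call pays for synchronization and cache clearing, one option is to measure once per input size and reuse the result. The sketch below is a hypothetical pattern, not part of torch_sim: `cache_by_size` and the size key are illustrative, and it assumes peak memory depends only on that key (for example, the atom count). In practice, the `measure` callable would wrap a call to `measure_model_memory_forward`.

```python
from typing import Callable, Dict, Hashable


def cache_by_size(measure: Callable[[Hashable], float]) -> Callable[[Hashable], float]:
    """Return a wrapper that runs `measure` at most once per size key.

    Assumes (for illustration only) that peak memory depends solely on the key.
    """
    results: Dict[Hashable, float] = {}

    def cached(size_key: Hashable) -> float:
        if size_key not in results:
            # Expensive path: sync, cache clear, and a full forward pass.
            results[size_key] = measure(size_key)
        return results[size_key]

    return cached


calls = []


def fake_measure(n_atoms: int) -> float:
    # Stand-in for a real GPU measurement; records each invocation.
    calls.append(n_atoms)
    return 0.001 * n_atoms


peak_gb = cache_by_size(fake_measure)
peak_gb(1000)
peak_gb(1000)  # served from the cache, no second measurement
peak_gb(2000)
print(calls)  # [1000, 2000]
```

The trade-off is accuracy: if two states share a key but differ in structure (for example, neighbor density), the cached value may under- or overestimate the second one.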