measure_model_memory_forward
- torch_sim.autobatching.measure_model_memory_forward(state, model)
Measure peak GPU memory usage during a model’s forward pass.
Clears GPU cache, runs a forward pass with the provided state, and measures the maximum memory allocated during execution. This function helps determine the actual GPU memory requirements for processing a simulation state.
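A minimal sketch of this measurement pattern, assuming a CUDA device and a model that exposes a `device` attribute and is callable on a `SimState` (both assumptions for illustration, not guarantees of this API):

```python
import torch

def peak_forward_memory_gb(state, model):
    """Illustrative sketch only; not the library's implementation."""
    device = model.device  # assumed ModelInterface attribute
    if device.type == "cpu":
        raise ValueError("Memory measurement is only meaningful for GPU models.")
    torch.cuda.synchronize(device)
    torch.cuda.empty_cache()                    # release cached blocks
    torch.cuda.reset_peak_memory_stats(device)  # reset the peak counter
    model(state)                                # forward pass being measured
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device) / 1024**3  # bytes -> GB
```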
- Parameters:
state (SimState) – Input state to pass to the model, with shape information determined by the specific SimState instance.
model (ModelInterface) – Model to measure memory usage for, implementing the ModelInterface protocol.
- Returns:
Peak memory usage in gigabytes.
- Return type:
float
- Raises:
ValueError – If the model device is CPU, as memory estimation is only meaningful for GPU-based models.
Notes
This function synchronizes the device and clears the GPU cache before measurement, which may impact performance if called frequently.
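A hedged usage example, where `model` and `state` stand in for any GPU-resident ModelInterface implementation and a compatible SimState (both placeholders here):

```python
from torch_sim.autobatching import measure_model_memory_forward

# `model` and `state` are placeholders: any GPU-based ModelInterface
# implementation and a SimState prepared for it.
peak_gb = measure_model_memory_forward(state, model)
print(f"Peak GPU memory for one forward pass: {peak_gb:.2f} GB")
```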