calculate_memory_scaler
- torch_sim.autobatching.calculate_memory_scaler(state, memory_scales_with='n_atoms_x_density')
Calculate a metric that estimates memory requirements for a state.
Provides different scaling metrics that correlate with memory usage. Models with radial neighbor cutoffs generally scale with “n_atoms_x_density”, because the number of neighbors inside a fixed cutoff sphere grows with the number density. Models with a fixed number of neighbors scale with “n_atoms”. The choice of metric can significantly affect how accurately memory requirements are estimated for different types of simulation systems.
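Written out for a single system, the two metrics are roughly as follows. This is a sketch: the symbols N (atom count), V (cell volume), and h (the 3×3 cell matrix) are introduced here for illustration and are not names from the library, and the library may rescale V for unit reasons, which changes the metric only by a constant factor.

```latex
m_{\text{n\_atoms}} = N,
\qquad
m_{\text{n\_atoms\_x\_density}} = N \cdot \rho = N \cdot \frac{N}{V} = \frac{N^{2}}{V},
\qquad
V = \lvert \det \mathbf{h} \rvert
```

Doubling the atom count at fixed density doubles both metrics, while compressing the same atoms into half the volume leaves the first unchanged and doubles the second.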
- Parameters:
state (SimState) – State for which to calculate the metric. Shape information is read from this SimState instance.
memory_scales_with ("n_atoms_x_density" | "n_atoms") – Type of metric to use. “n_atoms” uses only the atom count and suits models with a fixed number of neighbors. “n_atoms_x_density” multiplies the atom count by the number density and suits models with radial cutoffs. Defaults to “n_atoms_x_density”.
- Returns:
Calculated metric value.
- Return type:
float
- Raises:
ValueError – If state has multiple batches or if an invalid metric type is provided.
Example:

    from torch_sim.autobatching import calculate_memory_scaler

    # `state` is assumed to be an existing single-batch SimState
    # Calculate memory scaling factor based on atom count
    metric = calculate_memory_scaler(state, memory_scales_with="n_atoms")
    # Calculate memory scaling factor based on atom count and density
    metric = calculate_memory_scaler(state, memory_scales_with="n_atoms_x_density")
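To see how the two metrics diverge, here is a minimal, dependency-free sketch of the calculation for one system. It is not the library's implementation: `memory_metric_sketch` and `det3` are hypothetical helpers, the state is replaced by a plain atom count and a 3×3 cell given as nested lists, and volume is used in whatever units the cell is given in (the library may normalize units differently).

```python
from typing import Literal, Sequence

# Hypothetical alias mirroring the two documented metric names.
MemoryScaling = Literal["n_atoms_x_density", "n_atoms"]


def det3(cell: Sequence[Sequence[float]]) -> float:
    """Determinant of a 3x3 cell matrix whose rows are lattice vectors."""
    (a, b, c), (d, e, f), (g, h, i) = cell
    return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)


def memory_metric_sketch(
    n_atoms: int,
    cell: Sequence[Sequence[float]],
    memory_scales_with: MemoryScaling = "n_atoms_x_density",
) -> float:
    """Hypothetical stand-in for calculate_memory_scaler on a single system."""
    if memory_scales_with == "n_atoms":
        # Fixed neighbor count: memory grows linearly with the atom count.
        return float(n_atoms)
    if memory_scales_with == "n_atoms_x_density":
        # Radial cutoff: neighbors per atom grow with number density N / V.
        volume = abs(det3(cell))
        number_density = n_atoms / volume
        return n_atoms * number_density
    raise ValueError(f"Invalid metric type: {memory_scales_with!r}")


# Same 8 atoms in a cubic cell of side 2 (V = 8) and of side 4 (V = 64).
dense = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
sparse = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]

print(memory_metric_sketch(8, dense, "n_atoms"))   # 8.0
print(memory_metric_sketch(8, sparse, "n_atoms"))  # 8.0
print(memory_metric_sketch(8, dense))              # 8.0
print(memory_metric_sketch(8, sparse))             # 1.0
```

With "n_atoms" both systems look identical, while "n_atoms_x_density" rates the dense cell eight times more expensive, which matches the intuition that a radial-cutoff model finds far more neighbors per atom in the denser system.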