calculate_memory_scaler

torch_sim.autobatching.calculate_memory_scaler(state, memory_scales_with='n_atoms_x_density')

Calculate a metric that estimates memory requirements for a state.

Provides two scaling metrics that correlate with memory usage. Models with radial neighbor cutoffs generally scale with “n_atoms_x_density”, because the number of neighbors inside a fixed cutoff grows with number density. Models with a fixed number of neighbors scale with “n_atoms”. Choosing the metric that matches the model matters: a mismatched metric can give noticeably less accurate memory estimates, especially when systems differ in density.
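The two metrics described above can be sketched in plain Python. This is a minimal illustration, not torch_sim's implementation: the helper names `cell_volume` and `sketch_memory_scaler` are hypothetical, the cell is assumed to be a 3x3 matrix of lattice vectors, and the volume is used in its raw units (the library may apply its own unit scaling).

```python
from typing import Sequence

Matrix3 = Sequence[Sequence[float]]


def cell_volume(cell: Matrix3) -> float:
    """Volume of a periodic cell: absolute determinant of its 3x3 matrix."""
    (a, b, c), (d, e, f), (g, h, i) = cell
    return abs(a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g))


def sketch_memory_scaler(
    n_atoms: int,
    cell: Matrix3,
    memory_scales_with: str = "n_atoms_x_density",
) -> float:
    """Hypothetical stand-in for the metric; assumes a single system."""
    if memory_scales_with == "n_atoms":
        # Fixed neighbor count: cost grows with atom count only.
        return float(n_atoms)
    if memory_scales_with == "n_atoms_x_density":
        # Radial cutoff: neighbors per atom grow with number density,
        # so the metric is atom count times (atoms per unit volume).
        number_density = n_atoms / cell_volume(cell)
        return n_atoms * number_density
    raise ValueError(f"Invalid metric: {memory_scales_with!r}")


# Same atom count in two cubic cells; the second has half the edge length.
large_cell = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]
small_cell = [[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0]]

loose = sketch_memory_scaler(100, large_cell)
dense = sketch_memory_scaler(100, small_cell)
by_count = sketch_memory_scaler(100, small_cell, memory_scales_with="n_atoms")
```

Under these assumptions, halving the cell edge raises the density-based metric by a factor of eight while the “n_atoms” metric is unchanged, which is why the density-aware metric suits cutoff-based models.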

Parameters:
  • state (SimState) – State to calculate the metric for, with shape information specific to the SimState instance. The state must contain a single batch.

  • memory_scales_with ("n_atoms_x_density" | "n_atoms") – Type of metric to use. “n_atoms” uses only the atom count and is suitable for models that have a fixed number of neighbors. “n_atoms_x_density” uses the atom count multiplied by the number density and is better for models with radial cutoffs. Defaults to “n_atoms_x_density”.

Returns:

Calculated metric value.

Return type:

float

Raises:

ValueError – If state has multiple batches or if an invalid metric type is provided.

Example:

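from torch_sim.autobatching import calculate_memory_scaler

# `state` is assumed to be an existing single-batch SimState
# Calculate memory scaling factor based on atom count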
metric = calculate_memory_scaler(state, memory_scales_with="n_atoms")

# Calculate memory scaling factor based on atom count and density
metric = calculate_memory_scaler(state, memory_scales_with="n_atoms_x_density")