high_precision_sum

torch_sim.transforms.high_precision_sum(x, dim=None, *, keepdim=False)

Sums tensor elements over specified dimensions at 64-bit precision.

This function casts the input tensor to a 64-bit type, performs the summation in that type, and casts the result back to the original dtype. Accumulating in higher precision reduces the rounding error that builds up when many values are summed, or when large values cancel. This matters most for low-precision floating-point inputs such as float32 or float16.
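The cast-up, sum, cast-back steps can be sketched in plain PyTorch as follows. This is a minimal illustration, not torch_sim's source. `high_precision_sum_sketch` is a hypothetical name, and the accumulation dtypes chosen for complex and integer inputs are assumptions.

```python
import torch


def high_precision_sum_sketch(x, dim=None, *, keepdim=False):
    """Hypothetical stand-in for high_precision_sum: widen, sum, narrow."""
    # Assumption: complex inputs accumulate in complex128, real floating-point
    # inputs in float64, and everything else (integers, bool) in int64.
    if torch.is_complex(x):
        acc_dtype = torch.complex128
    elif torch.is_floating_point(x):
        acc_dtype = torch.float64
    else:
        acc_dtype = torch.int64

    # Normalize dim so that dim=None reduces over every dimension.
    if dim is None:
        dims = tuple(range(x.dim()))
    elif isinstance(dim, int):
        dims = (dim,)
    else:
        dims = tuple(dim)

    # Accumulate in the wide dtype, then round once back to the input dtype.
    total = x.to(acc_dtype).sum(dim=dims, keepdim=keepdim)
    return total.to(x.dtype)


# float32 spacing near 1e8 is 8, so 1e8 + 1 rounds back to 1e8 in float32.
# Plain float32 accumulation may therefore lose the 1.0, depending on order.
x = torch.tensor([1e8, 1.0, -1e8], dtype=torch.float32)
print(high_precision_sum_sketch(x))  # tensor(1.)
```

In float64 every partial sum here is exact, so the only rounding happens at the final cast back to float32.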

Parameters:
  • x (Tensor) – Input tensor to sum.

  • dim (int | Iterable[int] | None) – Dimension(s) along which to sum. If None, sums over all dimensions.

  • keepdim (bool) – If True, retains each reduced dimension with length 1. Defaults to False.
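Assuming `dim` and `keepdim` follow the same conventions as `torch.sum` (an assumption, though the descriptions above match them), the output shapes work out as below. The widen, reduce, and narrow steps are written inline with standard PyTorch calls rather than through torch_sim.

```python
import torch

x = torch.ones(2, 3, 4, dtype=torch.float32)
x64 = x.to(torch.float64)  # widened copy that gets reduced

# dim=None: reduce over every dimension, giving a 0-d tensor.
full = x64.sum().to(x.dtype)
# A single int reduces one dimension.
rows = x64.sum(dim=1).to(x.dtype)
# A tuple of ints reduces several dimensions at once.
multi = x64.sum(dim=(0, 2)).to(x.dtype)
# keepdim=True leaves each reduced dimension in place with length 1.
kept = x64.sum(dim=(0, 2), keepdim=True).to(x.dtype)

print(full.shape, rows.shape, multi.shape, kept.shape)
# torch.Size([]) torch.Size([2, 4]) torch.Size([3]) torch.Size([1, 3, 1])
```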

Returns:

Sum of the elements, cast back to the original dtype of x

Return type:

Tensor
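Because the result is cast back to the input dtype, the extra precision only affects accumulation: the returned value is still rounded to the input dtype. The small float16 case below traces those steps by hand with standard PyTorch calls. It does not call torch_sim.

```python
import torch

# 2048 + 1 = 2049 is exact in float64, but float16 values between 2048 and
# 4096 are spaced 2 apart, so the final cast rounds 2049 to the even neighbor 2048.
x = torch.tensor([2048.0, 1.0], dtype=torch.float16)
wide_total = x.to(torch.float64).sum()  # exact intermediate: 2049.0
result = wide_total.to(x.dtype)          # rounded back to float16
print(result)  # tensor(2048., dtype=torch.float16)
```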

Example

>>> import torch
>>> from torch_sim.transforms import high_precision_sum
>>> x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
>>> high_precision_sum(x)
tensor(6.)