high_precision_sum
- torch_sim.transforms.high_precision_sum(x, dim=None, *, keepdim=False)
Sums tensor elements over specified dimensions at 64-bit precision.
This function casts the input tensor to a higher-precision 64-bit type, performs the summation, and then casts the result back to the original dtype. Accumulating at 64-bit precision reduces the rounding error that builds up when many values are summed, which matters most for lower-precision floating-point inputs such as float32.
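To make the cast, sum, cast-back pattern concrete, here is a minimal sketch in plain PyTorch. It is not the torch_sim implementation: `upcast_sum` is a hypothetical name, and mapping complex, floating-point, and integer inputs to complex128, float64, and int64 respectively is an assumption about what "64-bit" means for each dtype family. The `dim=None` branch reduces with `Tensor.sum()` and restores size-1 dimensions by hand for `keepdim=True`, so it does not depend on how a given PyTorch version handles `dim=None`.

```python
import torch


def upcast_sum(x: torch.Tensor, dim=None, *, keepdim: bool = False) -> torch.Tensor:
    """Hypothetical sketch: sum ``x`` in a 64-bit dtype, then cast back to ``x.dtype``."""
    # Assumption: use the 64-bit type from the same family as the input.
    if torch.is_complex(x):
        wide = torch.complex128
    elif torch.is_floating_point(x):
        wide = torch.float64
    else:  # integer and bool inputs
        wide = torch.int64

    x_wide = x.to(wide)
    if dim is None:
        # Reduce over every dimension; restore size-1 dims by hand if requested.
        total = x_wide.sum()
        if keepdim:
            total = total.reshape([1] * x.dim())
    else:
        total = x_wide.sum(dim=dim, keepdim=keepdim)

    # Cast the 64-bit result back to the caller's dtype.
    return total.to(x.dtype)


# 2**24 + 1 is not representable in float32, so adding the 1.0 terms one at a
# time to 2**24 in a float32 accumulator would drop them; float64 keeps them.
x = torch.tensor([2.0**24, 1.0, 1.0], dtype=torch.float32)
s = upcast_sum(x)
print(s.item(), s.dtype)
```

As a design note, `torch.sum` also accepts a `dtype` keyword that casts the input before reducing, so for floating-point inputs `torch.sum(x, dtype=torch.float64).to(x.dtype)` expresses the same idea in a single call.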
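- Parameters:
x (torch.Tensor): input tensor to sum.
dim: dimension(s) along which to sum. If None (the default), the sum is taken over all dimensions.
keepdim (bool): if True, the reduced dimensions are retained with size 1.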
- Returns:
Sum of elements cast back to original dtype
- Return type:
torch.Tensor
Example
>>> import torch
>>> from torch_sim.transforms import high_precision_sum
>>> x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
>>> high_precision_sum(x)
tensor(6.)