safe_mask

torch_sim.transforms.safe_mask(mask, fn, operand, placeholder=0.0)

Safely applies a function to the elements of a tensor selected by a boolean mask.

This function applies fn only to elements where mask is True, so that invalid inputs at masked-out positions (such as the negative entry in the example below) cannot introduce NaN or infinite values into the result. Masked-out positions are filled with the placeholder value.
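One common way to get this behaviour is the "double where" pattern: first replace the masked-out inputs with a fixed filler value, then apply fn everywhere and overwrite its output at masked-out positions with placeholder. The sketch below assumes that pattern and a filler of 0; `safe_mask_sketch` is a hypothetical name, the sketch accepts any callable rather than only a ScriptFunction, and it is not presented as torch_sim's actual source.

```python
from typing import Callable

import torch


def safe_mask_sketch(
    mask: torch.Tensor,
    fn: Callable[[torch.Tensor], torch.Tensor],
    operand: torch.Tensor,
    placeholder: float = 0.0,
) -> torch.Tensor:
    # Hypothetical re-implementation of the masking idea, not torch_sim's source.
    # Step 1: swap masked-out inputs for a filler (0 is an assumed choice), so fn
    # never sees the original values at those positions.
    filled = torch.where(mask, operand, torch.zeros_like(operand))
    # Step 2: apply fn to every entry, then keep its output only where mask is
    # True and write `placeholder` everywhere else.
    return torch.where(mask, fn(filled), torch.full_like(operand, placeholder))


# Usage: log is undefined at -1, but that entry is masked out.
x = torch.tensor([1.0, 2.0, -1.0])
mask = torch.tensor([True, True, False])
print(safe_mask_sketch(mask, torch.log, x))  # tensor([0.0000, 0.6931, 0.0000])

# Gradients: with an input of exactly 0 at a masked-out slot, a single
# torch.where(mask, torch.log(x), 0.0) yields a NaN gradient there, because
# log's backward divides a zero incoming gradient by zero.
x = torch.tensor([1.0, 2.0, 0.0], requires_grad=True)
safe_mask_sketch(mask, torch.log, x).sum().backward()
print(x.grad)  # tensor([1.0000, 0.5000, 0.0000])
```

The outer torch.where discards fn's output at masked-out positions. During the backward pass, the inner torch.where passes gradient to operand only where mask is True, so a NaN produced by differentiating fn at the filler value is dropped instead of reaching operand.grad.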

Parameters:
  • mask (Tensor) – Boolean tensor indicating which elements to process (True) or mask out (False)

  • fn (ScriptFunction) – TorchScript function to apply to the elements where mask is True

  • operand (Tensor) – Input tensor to apply the function to

  • placeholder (float) – Value to use for masked-out positions (default: 0.0)

Returns:

Result tensor where fn is applied to the elements where mask is True and the placeholder value is used for the masked-out elements

Return type:

Tensor

Example

>>> import torch
>>> from torch_sim.transforms import safe_mask
>>> x = torch.tensor([1.0, 2.0, -1.0])
>>> mask = torch.tensor([True, True, False])
>>> safe_mask(mask, torch.log, x)
tensor([0.0000, 0.6931, 0.0000])