safe_mask
- torch_sim.transforms.safe_mask(mask, fn, operand, placeholder=0.0)
Safely applies a function to masked values in a tensor.
This function applies the given function only to elements where the mask is True, so that masked-out values (for example, a negative input to torch.log) cannot introduce NaN or infinite values into the result. Masked-out positions are filled with the placeholder value.
- Parameters:
mask (Tensor) – Boolean tensor indicating which elements to process (True) or mask out (False)
fn (ScriptFunction) – TorchScript function to apply to the masked elements
operand (Tensor) – Input tensor to apply the function to
placeholder (float) – Value to use for masked-out positions (default: 0.0)
- Returns:
- Result tensor in which fn is applied to the elements where mask is True and the placeholder value fills the masked-out elements
- Return type:
Tensor
Example
>>> import torch
>>> from torch_sim.transforms import safe_mask
>>> x = torch.tensor([1.0, 2.0, -1.0])
>>> mask = torch.tensor([True, True, False])
>>> safe_mask(mask, torch.log, x)
tensor([0.0000, 0.6931, 0.0000])
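The behaviour described above can be illustrated with a minimal, self-contained sketch. It assumes a "double torch.where" pattern: masked-out entries of the operand are first replaced with a fill value (0.0 here, an assumption), fn is evaluated on that sanitized tensor, and the placeholder is then written into the masked-out positions. The name safe_mask_sketch is hypothetical; this is not torch_sim's actual source. The second half contrasts it with a naive torch.where(mask, torch.sqrt(x), 0.0) to show the effect on gradients.

```python
import torch


def safe_mask_sketch(mask, fn, operand, placeholder=0.0):
    """Hypothetical sketch of masked function application (not torch_sim's source)."""
    # Step 1 (assumed): replace masked-out entries with a fill value (0.0)
    # so fn never receives inputs such as negative numbers.
    sanitized = torch.where(mask, operand, torch.zeros_like(operand))
    # Step 2: evaluate fn on the sanitized tensor, then write the placeholder
    # into every masked-out position.
    return torch.where(mask, fn(sanitized), torch.full_like(operand, placeholder))


mask = torch.tensor([True, True, False])

# Naive masking: sqrt(-1.0) is NaN in the forward pass. torch.where hides it
# in the output, but the backward pass still propagates NaN to x_naive.grad.
x_naive = torch.tensor([1.0, 4.0, -1.0], requires_grad=True)
naive = torch.where(mask, torch.sqrt(x_naive), torch.zeros_like(x_naive))
naive.sum().backward()
naive_grad = x_naive.grad

# Sketch: the backward pass of the first torch.where sends zero gradient to
# masked-out entries, discarding any NaN produced inside fn's derivative.
x_safe = torch.tensor([1.0, 4.0, -1.0], requires_grad=True)
safe = safe_mask_sketch(mask, torch.sqrt, x_safe)
safe.sum().backward()
safe_grad = x_safe.grad

print(naive_grad)  # the masked-out entry is nan
print(safe_grad)   # the masked-out entry is 0
```

A single torch.where only selects among outputs that have already been computed, so a NaN produced for a masked-out entry can still reach the gradient; sanitizing the input first is what keeps the backward pass finite in this sketch.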