to_one_hot

torch_sim.models.mace.to_one_hot(indices, num_classes, dtype)

Generates one-hot encodings from class indices.

NOTE: This is a modified version of the to_one_hot function in mace.tools. Consider switching to the upstream version, if possible, once https://github.com/ACEsuit/mace/pull/903/ is merged.

Parameters:
  • indices (Tensor) – A tensor of shape (N x 1) containing class indices.

  • num_classes (int) – An integer specifying the total number of classes.

  • dtype (dtype) – The desired data type of the output tensor.

Returns:

A tensor of shape (N x num_classes) containing the one-hot encodings.

Return type:

Tensor
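To illustrate the input/output contract described above, here is a minimal sketch of an equivalent computation in PyTorch. It is not the torch_sim source: the name one_hot_sketch is hypothetical, and the scatter-based approach is an assumption chosen for illustration. It assumes indices is an integer (int64) tensor of shape (N x 1), which is what Tensor.scatter_ requires for its index argument.

```python
import torch


def one_hot_sketch(indices: torch.Tensor, num_classes: int, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical sketch: one-hot encode an (N, 1) int64 index tensor."""
    # Replace the trailing size-1 dimension with num_classes: (N, 1) -> (N, num_classes).
    shape = tuple(indices.shape[:-1]) + (num_classes,)
    # Allocate zeros directly in the requested dtype, on the same device as the input.
    oh = torch.zeros(shape, dtype=dtype, device=indices.device)
    # For each row, write a 1 at the column given by that row's class index.
    oh.scatter_(dim=-1, index=indices, value=1)
    return oh


idx = torch.tensor([[0], [2], [1]])  # N = 3 samples, one class index each
print(one_hot_sketch(idx, num_classes=3, dtype=torch.float32))
```

Allocating the output in the target dtype avoids a separate cast afterwards. For a flat index vector, torch.nn.functional.one_hot(idx.squeeze(-1), num_classes).to(dtype) produces the same values, but it always returns int64 before the cast.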