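import torch

# Encode class indices 0, 1, 2 into vectors of length 5.
one_hot = torch.nn.functional.one_hot(torch.arange(3), num_classes=5)
print(one_hot)
# tensor([[1, 0, 0, 0, 0],
#         [0, 1, 0, 0, 0],
#         [0, 0, 1, 0, 0]])

# Encode arbitrary class indices; the input must be an integer (int64) tensor.
one_hot = torch.nn.functional.one_hot(torch.tensor([1, 3, 4]), num_classes=5)
print(one_hot)
# tensor([[0, 1, 0, 0, 0],
#         [0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 1]])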
Parameters
tensor (LongTensor) – class values of any shape.
num_classes (int) – Total number of classes. If set to -1, the number of classes will be inferred as one greater than the largest class value in the input tensor.
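As a small sketch of the num_classes inference described above: omitting num_classes (it defaults to -1) should give the same result as passing the largest class value plus one. The label values here are made up for illustration.

```python
import torch
import torch.nn.functional as F

# Illustrative labels; the largest class value is 4.
labels = torch.tensor([0, 2, 4])

# num_classes defaults to -1, so the width is inferred as max(labels) + 1 = 5.
inferred = F.one_hot(labels)
explicit = F.one_hot(labels, num_classes=5)

print(inferred.shape)                   # torch.Size([3, 5])
print(torch.equal(inferred, explicit))  # True
```

Relying on inference can be fragile when a batch happens not to contain the highest class, so passing num_classes explicitly is usually the safer choice.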
Returns
LongTensor with one more dimension than the input. Along the new last dimension, the value is 1 at the index given by the input class value and 0 everywhere else.
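To make the return shape concrete, here is a minimal sketch with a 2-D input. The label values are illustrative; the point is that the output gains a trailing dimension of size num_classes and keeps the int64 dtype.

```python
import torch
import torch.nn.functional as F

# Illustrative 2-D batch of class indices, shape (3, 2).
labels = torch.tensor([[0, 1],
                       [2, 0],
                       [1, 2]])

encoded = F.one_hot(labels, num_classes=3)

print(encoded.shape)  # torch.Size([3, 2, 3])
print(encoded.dtype)  # torch.int64

# Taking argmax over the last dimension recovers the original labels.
recovered = encoded.argmax(dim=-1)
print(torch.equal(recovered, labels))  # True
```

Because the result is an integer tensor, code that feeds it into a loss expecting floating-point targets typically converts it first, for example with encoded.float().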