import torch
class_num = 3
batch_size = 2
label = torch.LongTensor(batch_size, 1).random_() % class_num
print(label)
#tensor([[6],
# [0],
# [3],
# [2]])
xx=torch.zeros(batch_size, class_num).scatter_(1, label, 1)##原来0的位置 职位1
print(xx)
'''
tensor([[1],#####################代表 01 10位置变为1
[0]])####################代表 01 10位置变为1
tensor([[0., 1., 0.],
[1., 0., 0.]])
tensor([[0],##################代表 00 12
[2]])
tensor([[1., 0., 0.],
[0., 0., 1.]])
'''
#tensor([0.9000, 0.1000])
def _get_gt_mask(logits, target):