# Round-trip demo: convert integer class labels to one-hot vectors and back.
# Relies on `torch` and `test_dataloader` being defined elsewhere in the file;
# each batch is unpacked as an (inputs, labels) pair.
for _inputs, labels in test_dataloader:
    # The original author observed torch.cuda.ByteTensor (uint8) here;
    # confirm for your dataloader.
    print(labels.type())
    # one_hot() only accepts int64 (LongTensor) class indices, so cast first.
    labels = labels.long()
    print(labels)
    # one_hot appends a trailing dimension of size num_classes. 256 covers
    # every value a uint8 label can hold (0-255); larger labels are rejected.
    onehot_labels = torch.nn.functional.one_hot(labels, num_classes=256)
    # Spot-check a single one-hot entry. This indexing assumes `labels` has at
    # least two dimensions (e.g. a batch of masks) -- TODO confirm.
    print(onehot_labels[0][0][5])
    # argmax over the one-hot (last) dimension recovers the original labels.
    label = torch.argmax(onehot_labels, -1)
    print(label)
PyTorch标签转化为one-hot形式，再转换回去
最新推荐文章于 2023-04-25 12:41:34 发布