def label_encoder(num_classes, label):
new_tensor = torch.zeros(2,num_classes, 3, 2)
for i in range(num_classes):
# 将原始标签张量中值为0的位置置为1
new_tensor[:,i] = (label == i).float()
return new_tensor
one_hot编码
最新推荐文章于 2024-07-12 23:11:20 发布