自己记录用的:
batch_label_ce = torch.empty([len(batch_data)], dtype=torch.long)
for i in range(len(batch_data)):
if batch_data[i,:] == [1., 0., 0.]:
batch_label_ce[i] = 0
elif batch_data[i,:] == [0., 1., 0.]:
batch_label_ce[i] = 1
elif batch_data[i,:] == [0., 0., 1.]:
batch_label_ce[i] = 2
之前的时候写的是
1/未加小数点。
2/未注意行列关系。
3/未检查后面使用否有修改。