RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward
- An error occurs when computing the loss with PyTorch's **nn.CrossEntropyLoss()**:
```
    loss = loss_fn(outputs, targets)
  File "D:\anaconda3\envs\py3\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\anaconda3\envs\py3\lib\site-packages\torch\nn\modules\loss.py", line 948, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "D:\anaconda3\envs\py3\lib\site-packages\torch\nn\functional.py", line 2422, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "D:\anaconda3\envs\py3\lib\site-packages\torch\nn\functional.py", line 2218, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward
```
- Printing `outputs` and `targets` at this point gives:
```
outputs: tensor([[-3.9258, -1.4330, -5.0279, 20.8584, -6.3585, 8.7937, -0.6233, -5.7761,
          0.5244, -7.1076]], device='cuda:0')
targets: tensor([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]], device='cuda:0')
```
- Cause
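When computing the cross-entropy loss, you do not need to convert the target to one-hot format yourself (one-hot looks like [0,0,0,1,0]). `nn.CrossEntropyLoss` expects class indices instead: the target [0,0,0,1,0] should be given as 3. Internally the loss picks out the log-probability of that class, which is equivalent to multiplying by the one-hot vector, so the one-hot form is never needed. (PyTorch 1.10 and later also accept float class probabilities with the same shape as the input, but the version that produced this traceback evidently does not.)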
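To make the index-versus-one-hot point concrete, here is a minimal sketch that reuses the logits printed above (moved to CPU). It converts the one-hot row to a class index with `argmax` and checks that `nn.CrossEntropyLoss` gives the same value as weighting `log_softmax` by the one-hot vector by hand. The manual formula is only an illustration of the equivalence, not how PyTorch implements it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Logits and one-hot target copied from the printout above (CPU for this sketch)
outputs = torch.tensor([[-3.9258, -1.4330, -5.0279, 20.8584, -6.3585, 8.7937,
                         -0.6233, -5.7761, 0.5244, -7.1076]])
one_hot = torch.tensor([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]])

# One-hot row -> class index: [0,0,0,1,0,...] -> 3
targets = one_hot.argmax(dim=1)  # shape (N,), dtype torch.int64
print(targets, targets.dtype)  # tensor([3]) torch.int64

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(outputs, targets)

# Same value by hand: the one-hot weighting keeps only the target-class log-probability
manual = -(F.log_softmax(outputs, dim=1) * one_hot).sum(dim=1).mean()
print(torch.allclose(loss, manual))  # True
```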
This can be fixed by changing the label returned from the Dataset's `__getitem__` method:
```python
import numpy as np
import torch
from PIL import Image


class myTestDataset(torch.utils.data.Dataset):
    def __init__(self, transform=None):
        images = np.load('data_src.npy')
        labels = np.load('label_src.npy')
        self.images = [Image.fromarray(x) for x in images]
        # Normalize each label row so it sums to 1, stored as float32
        self.labels = labels / labels.sum(axis=1, keepdims=True)
        self.labels = self.labels.astype(np.float32)
        self.transform = transform

    def __getitem__(self, index):
        image = self.images[index]
        label = self.labels[index]
        # For CrossEntropyLoss, turn the one-hot row into a class index
        label = int(np.argmax(label))
        if self.transform:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.images)
```
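To check what labels returned this way look like after batching, here is a self-contained sketch. It assumes a hypothetical in-memory variant of the dataset (the class name `OneHotToIndexDataset`, the random images and the random one-hot labels stand in for `data_src.npy`/`label_src.npy`, and PIL is skipped), and shows that `DataLoader` collates the integer labels into a `torch.int64` tensor of shape (N,), the form `nn.CrossEntropyLoss` expects.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class OneHotToIndexDataset(Dataset):
    """Hypothetical in-memory stand-in for myTestDataset (no .npy files, no PIL)."""

    def __init__(self, images, labels):
        self.images = images.astype(np.float32)
        self.labels = labels.astype(np.float32)  # assumed to be one-hot rows

    def __getitem__(self, index):
        label = int(np.argmax(self.labels[index]))  # one-hot row -> class index
        return self.images[index], label

    def __len__(self):
        return len(self.images)


rng = np.random.default_rng(0)
images = rng.random((8, 28, 28))                   # 8 fake grayscale images
labels = np.eye(10)[rng.integers(0, 10, size=8)]   # 8 random one-hot rows

loader = DataLoader(OneHotToIndexDataset(images, labels), batch_size=4)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape, yb.dtype)  # torch.Size([4, 28, 28]) torch.Size([4]) torch.int64
```

Returning a plain Python `int` keeps the label independent of NumPy's integer types; the default collate function turns a batch of ints into an int64 tensor either way.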
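Also note that the targets passed to the loss must be of type int64 (`torch.long`). If they arrive with another dtype, cast them when computing the loss: `loss = loss_fn(outputs, targets.long())`. The cast only changes the dtype, though: the targets must already be class indices of shape (N,), because calling `.long()` on a one-hot tensor like the one above still raises an error.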
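If you would rather leave the Dataset unchanged, the conversion can also happen in the training loop. This sketch compares the two options on a hypothetical random batch of 2 samples and 10 classes: casting the one-hot targets with `.long()` alone fails, while `targets.argmax(dim=1)` yields int64 class indices that the loss accepts.

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

# Hypothetical batch: random logits and one-hot float targets,
# as they would arrive from a Dataset that still returns one-hot labels
outputs = torch.randn(2, 10)
targets = torch.tensor([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
                        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]])

# .long() only changes the dtype; the shape is still (N, C), so the loss rejects it
try:
    loss_fn(outputs, targets.long())
    cast_only_failed = False
except RuntimeError:
    cast_only_failed = True

# argmax over the class dimension gives indices of shape (N,) and dtype int64
class_idx = targets.argmax(dim=1)
loss = loss_fn(outputs, class_idx)
print(class_idx, class_idx.dtype)  # tensor([3, 1]) torch.int64
```

On a GPU the same calls apply unchanged; `argmax` keeps the result on the targets' device.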