Manually implementing torch.nn.CrossEntropyLoss()
Manual implementation of the loss function
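For a batch of N samples with logits z_i over C classes and true class y_i, cross-entropy is the mean negative log-softmax probability of the true class (the same quantity nn.CrossEntropyLoss computes with its default reduction='mean'):

\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(z_{i,y_i})}{\sum_{c=1}^{C} \exp(z_{i,c})}

The function below builds this up step by step: one-hot encoding, softmax, log, then the averaged negative log-likelihood.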
import torch
import torch.nn as nn
import torch.nn.functional as F

def cross_entropy_loss(y_pred, y_true):
    one_hot = F.one_hot(y_true, num_classes=y_pred.shape[1]).float()  # one-hot encode the true labels; num_classes matches the logits width
    softmax = torch.exp(y_pred) / torch.sum(torch.exp(y_pred), dim=1).reshape(-1, 1)  # softmax over the class dimension
    logsoftmax = torch.log(softmax)
    nllloss = -torch.sum(one_hot * logsoftmax) / y_true.shape[0]  # negative log-likelihood averaged over the batch
    return nllloss
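The version above exponentiates raw logits, so large values can overflow exp() to inf and turn the loss into nan. A minimal sketch of a numerically stable variant (the helper name cross_entropy_loss_stable is mine), using the standard max-subtraction trick before the log-sum-exp:

def cross_entropy_loss_stable(y_pred, y_true):
    # Softmax is shift-invariant, so subtracting the per-row max changes nothing
    # mathematically but keeps exp() from overflowing
    z = y_pred - y_pred.max(dim=1, keepdim=True).values
    logsoftmax = z - torch.log(torch.exp(z).sum(dim=1, keepdim=True))
    # Gather the log-probability of the true class for each sample and average
    return -logsoftmax[torch.arange(y_true.shape[0]), y_true].mean()

This skips the explicit one-hot encoding and the separate softmax/log steps, and mirrors the documented decomposition of nn.CrossEntropyLoss into LogSoftmax plus NLLLoss.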
Testing the validity of the handwritten loss function
Feed the same prediction tensor and ground-truth labels into both the handwritten loss function and nn.CrossEntropyLoss(); if the two outputs agree, the handwritten loss function is correct.
Assume the classification task has 7 classes.
logits holds the predicted outputs, with shape [batch_size, num_labels], i.e. [7, 7].
labels holds the ground-truth labels, with shape [batch_size], i.e. [7].
Note: in cross_entropy_loss, num_classes is taken from the logits width (y_pred.shape[1]), so the one-hot encoding always matches the shape of logsoftmax. Without that argument, F.one_hot would infer the class count from the largest label, and one_hot * logsoftmax would fail with a shape mismatch whenever some of the 7 classes were missing from labels (see the illustration below).
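As a quick illustration of the note above (toy labels invented for this example):

toy = torch.tensor([0, 1, 1])                 # class 2 never appears
print(F.one_hot(toy).shape)                   # torch.Size([3, 2]), inferred from the labels
print(F.one_hot(toy, num_classes=3).shape)    # torch.Size([3, 3]), matches the logits width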
logits = torch.tensor([[-0.1378, 0.3560, -0.1881, -0.1667, -0.1741, 0.3571, -0.2159],
                       [-0.1421, 0.3805, -0.1834, -0.1769, -0.2232, 0.3823, -0.2155],
                       [-0.0975, 0.4085, -0.1694, -0.1913, -0.2099, 0.3725, -0.2611],
                       [-0.1729, 0.4593, -0.1960, -0.2224, -0.2111, 0.4147, -0.3108],
                       [-0.1913, 0.4942, -0.2568, -0.1854, -0.2764, 0.4796, -0.3428],
                       [-0.4058, 0.9523, -0.5067, -0.4886, -0.5095, 0.7811, -0.5855],
                       [-0.3806, 1.1320, -0.5704, -0.5436, -0.6376, 0.9648, -0.6508]])
labels = torch.tensor([0, 1, 2, 3, 4, 5, 6])
loss_value = cross_entropy_loss(logits, labels)
print("Custom implementation:", loss_value)
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)
print("nn.CrossEntropyLoss:", loss)
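To make the comparison automatic rather than visual, one can assert that the two results agree within floating-point tolerance (torch.allclose with its default tolerances):

assert torch.allclose(loss_value, loss), "handwritten loss disagrees with nn.CrossEntropyLoss"
print("Results match:", torch.allclose(loss_value, loss))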