pytorch的交叉熵损失函数是如何计算outputs和 labels之间的损失的?
对于一个分类问题的CNN模型,最后一层的代码一般如下:
nn.Linear(2048, num_classes)
然后计算一次迭代损失的代码一般如下:
loss_function = nn.CrossEntropyLoss()
outputs = net(images.to(device))
loss = loss_function(outputs, labels.to(device))
1、假设条件:
假设batsize=4,我们的任务是一个5分类的,即num_classes=5,labels=【人,猫,狗,兔,鸟】
但是我们做数据集的时候,一般用[0,1,2,3,4]来代替,即0代表人,1代表猫....,这些数字代表了onehot中元素为1的位置
2、output
那么神经网络的output形状为【4x5】4行5列。(以一张图片为一行,一共4行。每个图片可能的预测结果有5个,所以有5列)
由于网络直接输出的output不是概率,在计算损失时,会首先将output以行为单位计算Softmax,即被预测为每个类别的概率,使得它相加为1,,然后再取Log,
3、label
label会被变成onehot形式,假设这四个图片标签分别为:人,猫,狗,鸟
那么各自对应的onehot分别为[[1,0,0,0,0],
[0,1,0,0,0],
[0,0,1,0,0],
[0,0,0,0,1]],组成一个和output形状一样的tensor格式的标签
4、loss
最后计算onehot中元素为1的位置对output中的数字取logsoftmax后,求和,取反,再求均值即为最后的交叉熵损失
如果使用NLLLoss需要将网络的输出先经过LogSoftmax,CrossEntropyLoss = LogSoftmax 和 NLLLoss,