交叉熵损失,softmax函数和 torch.nn.CrossEntropyLoss()中文

背景

多分类问题里(单对象单标签),一般问题的setup都是一个输入,然后对应的输出是一个vector,这个vector的长度等于总共类别的个数。输入进入到训练好的网络里,predicted class就是输出层里值最大的那个entry对应的标签。

交叉熵在多分类神经网络训练中用的最多的loss function(损失函数)。 举一个很简单的例子,我们有一个三分类问题,对于一个input \(x\),神经网络最后一层的output (\(y\))是一个\((3 \times 1)\)的向量。然后这个\(x\)对应的ground-truth(\(y^{'}\) )也是一个\((3 \times 1)\)的向量。

交叉熵

来举一个例子,让三个类别分别是类别0,1和2。这里让input \(x\)属于类别0。所以ground-truth(\(y^{'}\) ) 就等于\((1,0,0)\), 让网络的预测输出(\(y\))等于\((3,1,-3)\)
\[ y = (3,1,-3), \space y^{'} = (1,0,0) \]
交叉熵损失的定义如下公式所示(在上面的列子里,i是从0到2的):
\[ H (y,y^{'})= -\sum_i {y_i^{'}}\log(softmax(y_i)) \]

Softmax

softmax的计算可以在下图找到。注意在图里,softmax的输入\((3,1,-3)\) 是神经网络最后一个fc层的输出(\(y\))。\(y\)经过softmax层之后,就变成了\(softmax(y)=(0.88,0.12,0)\)\(y\)的每一个entry可以看作每一个class的预测得分,那么\(softmax(y)\)的每一个entry就是每一个class的预测概率。
\[ H (y,y^{'})= -{y_j^{'}}\log(softmax(y_j)) \]
对于上面的列子,当前\(x\)的分类loss就是\(H(y,y^{'})=-1\times \log(0.88)=0.12\) (注意,这里\(\log\)的base是\(e\))

softmax常用于多分类过程中,它将多个神经元的输出,归一化到\((0, 1)\) 区间内,因此Softmax的输出可以看成概率,从而来进行多分类。

1475891-20181030204227478-1135460200.png

nn.CrossEntropyLoss() in Pytorch

其实归根结底,交叉熵损失的计算只需要一个term。这个term就是在softmax输出层中找到ground-truth里正确标签对应的那个entry \(j\) ,也就是(\(\log(softmax(y_j))\))。(当然咯,在计算\(softmax(y_j)\)的时候,我们是需要y里所有的term的值的。)
\[ H (y,y^{'})= -{y_j^{'}}\log(softmax(y_j)) \]
因为entry \(j\)对应的是ground-truth里正确的class。只有在\(i=j\)的时候才\(y^{'}_i = 1\),其他时候都等于0。

在下面的代码里,我们把python中torch.nn.CrossEntropyLoss() 的计算结果和用公式计算出的交叉熵结果进行比较. 结果显示,torch.nn.CrossEntropyLoss()的input只需要是网络fc层的输出\(y\), 在torch.nn.CrossEntropyLoss()里它会自己把\(y\) 转化成\(softmax(y)\) 然后再进行交叉熵loss的运算.

所以当我们用PyTorch搭建分类网络的时候,不需要再在最后一个fc层后再手动添加一个softmax层。

注意,在用PyTorch做分类问题的时候,在网络搭建时(假设全连接层的output是y),在之后加一个 y = torch.nn.functional.log_softmax (y),并在训练时,用torch.nn.functional.nll_loss(y, labels)。这样达到的效果和不用log_softmax层,并用torch.nn.CrossEntropyLoss(y,labels)做损失函数是一模一样的。

import torch
import torch.nn as nn
import math

output = torch.randn(1, 5, requires_grad = True) #假设是网络的最后一层,5分类
label = torch.empty(1, dtype=torch.long).random_(5) # 0 - 4, 任意选取一个分类

print ('Network Output is: ', output)
print ('Ground Truth Label is: ', label)

score = output [0,label.item()].item() # label对应的class的logits(得分)
print ('Score for the ground truth class = ', label)

first = - score
second = 0
for i in range(5):
    second += math.exp(output[0,i])
second = math.log(second)

loss = first + second
print ('-' * 20)
print ('my loss = ', loss)

loss = nn.CrossEntropyLoss()
print ('pytorch loss = ', loss(output, label))

转载于:https://www.cnblogs.com/fledlingbird/p/10718096.html

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值