理解与使用torch的CrossEntropy Loss

交叉熵误差函数形式不难,但是使用起来却容易出错,这里记录一下自己的理解。

使用形式

torch提供了类和函数两种形式来使用交叉熵误差函数,其中类的形式在源码中调用了函数形式,所以使用哪一种都可以:

importtorch.nnasnnimporttorch.nn.functionalasF# 函数形式:out=F.cross_entropy(pred,target)# 类的形式:func=nn.CrossEntropy()# 括号不要掉,定义一个对象out=func(pred,target)

底层计算逻辑

假设输入的是一个shape=[3, 5]的张量,3表示batch_size(N),5表示类别数(C),这一步可以是神经网络的最后一层全连接层的输出。先定义好输入和标签:

a=torch.randn(3,5)print(a)----------------tensor([[-0.6311,0.4457,0.5320,0.5855,-0.2480],[0.5942,0.2531,0.7998,-0.4247,-0.2831],[1.5852,1.9026,0.7302,-2.0489,-0.5581]])target=torch.tensor([2,4,3],dtype=torch.long)

对于这样的输出容易理解:3个样本,每一个样本分别属于5种类别的“得分”。因为概率值一定是[0, 1]的数,所以需要使用softmax函数来将“得分”转化为“概率值”:

input=a.softmax(dim=1)print(input)----------------------tensor([[0.0835,0.2451,0.2672,0.2818,0.1225],[0.2691,0.1913,0.3305,0.0971,0.1119],[0.3398,0.4668,0.1445,0.0090,0.0399]])

这里的dim参数为什么=1呢?因为我们需要知道的是,每一个样本属于C种类别的概率是多少,所以样本与样本之间是独立的,而这里每一行是一个样本,所以需要沿着横轴同时将5个数据映射到softmax函数。

随后,就可以计算交叉熵误差,公式:

�=−∑�=1��������ln⁡������

代码:

-(torch.log(input[0][target[0]])+torch.log(input[1][target[1]])+torch.log(input[2][target[2]]))/3----------------------tensor(2.7411)

这里有几个问题:

代码这样写看上去target中的值只是索引?

target值确实是索引,只是这里的处理逻辑和torch的处理逻辑不同,torch先将target转换成shape同input的one-hot编码形式:

importtorchimporttorch.nn.functionalasFdefto_one_hot(tensor,num_clsses):asserttensor.dim()<=1,"[Error] tensor.dim >= 1"one_hot=torch.zeros(len(tensor),num_clsses)idx=range(len(tensor))one_hot[idx,tensor.reshape(-1,)]=1returnone_hotbatch_size=10num_classes=5target=torch.randint(num_classes,size=(batch_size,))# 非delta标签分布target_delta=to_one_hot(target,num_classes)

然后将target和input进行元素对应相乘,然后使用log函数。这里的处理方式就直接分别取出三个样本的标签值的索引,三个样本的标签索引分别是2, 4, 3,所以应该取出这三个值:

然后映射到log函数中:

torch.log

得到三个负值,三个负值相加取均值,得到一个batch的平均损失值(如果从input中选出的三个值越大,那么映射到log中值的绝对值越小,也就是损失越小)。

2. 取均值还是累加和?

torch源码默认取的是均值,可以传入reduction参数来修改:

defcross_entropy(input:Tensor,target:Tensor,weight:Optional[Tensor]=None,size_average:Optional[bool]=None,ignore_index:int=-100,reduce:Optional[bool]=None,reduction:str="mean",# 默认取均值)->Tensor:

基于上述案例使用封装函数

注意这里的a是softmax操作前的结果,这是由于函数内部有softmax操作,具体见官方文档:

F.cross_entropy(a,target)----------------------tensor(2.7411)

binary cross entropy loss

交叉熵误差函数还有一种特殊情况,即二值交叉熵误差。比如一个正负例预测问题中,希望网络输出结果为1或0,网络输出的是一个值,需要首先将该值隐射到sigmoid函数中,以限制值的范围到(0, 1),然后再和标签值进行计算。其中pred的shape=(N, ),target的shape=(N, )。使用方式如下所示:

defloss_fn(pred,target):pred=pred.reshape(-1)target=target.reshape(-1)critical=nn.functional.binary_cross_entropy_with_logits(pred,target)returncritical

这里使用了binary_cross_entropy_with_logits,该函数自带sigmoid操作,所以网络输出值直接传进来即可,不需要我们自己手动sigmoid。

另外,如果我们的输入值趋近于0或1,那么经过log函数得到的结果趋近于无穷大,这种情况下反向求导会无法继续计算下去,pytorch的解决方案是将经过Log后的结果限制在[-100, 100]内的范围,这样就可以保证梯度存在且可以继续计算。

Reference:

(62条消息) pytorch语义分割中CrossEntropyLoss()损失函数的理解与分析_北斗星辰001的博客-CSDN博客

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值