Pytorch手撸交叉熵CrossEntropyLoss并修改one-hot输入

1 前言

楼主最近要修改一下one-hot然后送入交叉熵中,由于pytorch的torch.nn.CrossEntropyLoss()已经封装好了one-hot,所以需要自定义重写交叉熵,这里主要是多级交叉熵,而不是二分类交叉熵。

2 交叉熵的实现过程

首先看多级交叉熵的计算过程:

L = − 1 N ∑ i N ∑ c = 1 M y i c l o g ( p i c ) L = -\frac{1}{N}\displaystyle\sum_i^N\displaystyle\sum_{c=1}^My_{ic}log(p_{ic}) L=N1iNc=1Myiclog(pic)
其中

  • M M M表示多分类类别的数量;
  • y i c y_{ic} yic只为0和1,表示如果该样本 i i i 的类别为 c c c,则为1否则为0
  • p i c p_{ic} pic神经网络输出的样本 i i i 类别为 c c c 的概率

用程序来描述如下:
首先是 p i c p_{ic} pic,可以看出他就是神经网络输出后加入softmax形成的概率值,
在这里插入图片描述

然后是 y i c y_{ic} yic ,其含义是类别一致为1,不一致为0,很显然符合one-hot的定义,所以pytorch在交叉熵内部封装了one-hot来实现这一步。(注意,这里的y是一个矩阵,不仅仅只是 y i c y_{ic} yic,而应该是 y i y_i yi,表示一整个样本的真实标签)
在这里插入图片描述
连起来, ∑ c = 1 M y i c l o g ( p i c ) \displaystyle\sum_{c=1}^My_{ic}log(p_{ic}) c=1Myiclog(pic)可以表示为:
在这里插入图片描述
然后进行最后一步,对所有样本取平均,即 − 1 N ∑ i N -\frac{1}{N}\displaystyle\sum_i^N N1iN部分,由于这里我们只有一个样本,所以还是一样的结果,只是符号有变化。

在这里插入图片描述

3 完整的自定义交叉熵

所以由以上过程,我们自己实现交叉熵的过程为:

class Our_CrossEntropy(torch.nn.Module):

    def __init__(self):
        super(Our_CrossEntropy,self).__init__()
    
    def forward(self, x ,y):
        P_i = torch.nn.functional.softmax(x, dim=1)
        y = torch.nn.functional.one_hot(y)
        loss = y*torch.log(P_i + 0.0000001)
        loss = -torch.mean(torch.sum(loss, dim=1),dim = 0)
        return loss

验证一下:
在这里插入图片描述
没毛病~

torch.nn.CrossEntropyLoss()输出一致

4 使用自己one-hot的交叉熵

回到最初的问题,我需要使用自己修改之后的one-hot,所以自定义交叉熵后,去掉里面的one-hot部分,将one-hot写在函数外面进行修改再导入即可:

class one_hot_CrossEntropy(torch.nn.Module):

    def __init__(self):
        super(one_hot_CrossEntropy,self).__init__()
    
    def forward(self, x ,y):
        P_i = torch.nn.functional.softmax(x, dim=1)
        loss = y*torch.log(P_i + 0.0000001)
        loss = -torch.mean(torch.sum(loss,dim=1),dim = 0)
        return loss

验证一下:
在这里插入图片描述
没毛病~

修改一下one-hot,让他不是1:
在这里插入图片描述

成功~

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

锌a

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值