torch.nn.CrossEntropyLoss详解

本文主要讲述CrossEntropyLoss函数内部的计算逻辑,我觉得摆一大堆公式都不如直接上代码来的实在。

首先,我们先设置好两个变量,input与y_target。

importtorchimporttorch.nnasnnimportnumpyasnpa=np.arange(1,13).reshape(3,4)b=torch.from_numpy(a)input=b.float()print('input:\n',input)

可以看到input矩阵如下:

tensor([[1.,2.,3.,4.],[5.,6.,7.,8.],[9.,10.,11.,12.]])

然后设置y_target:

y_target=torch.tensor([1,2,3])print('y_target:\n',y_target)

这个不用打印大家也应该知道是什么样了。

input是一个【3 4】的矩阵,y-target是一个【1 3】的矩阵。input是预测值,代表有三个样本,四个类别。y-target代表三个样本的真实标签。

crossentropyloss=nn.CrossEntropyLoss(reduction='none')crossentropyloss_output=crossentropyloss(x_input,y_target)print('crossentropyloss_output:\n',crossentropyloss_output)

经过CrossEntropyLoss后,最终结果为:

crossentropyloss_output:tensor([2.4402,1.4402,0.4402])

下面我们来剖析它的计算过程。其实CrossEntropyLoss相当于softmax + log + nllloss

不信可以计算一遍:

softmax_func=nn.Softmax(dim=1)soft_output=softmax_func(input)print('soft_output:\n',soft_output)log_output=torch.log(soft_output)print('log_output:\n',log_output)nllloss_func=nn.NLLLoss(reduction='none')nllloss_output=nllloss_func(log_output,y_target)print('nllloss_output:\n',nlloss_output)

最终结果是一样的:

nllloss_output:tensor([2.4402,1.4402,0.4402])

softmax、log这两个函数应该都比较好理解。

下面主要讲解一下nllloss这个损失函数。

有了经过softmax与log的矩阵,我们叫它矩阵A。

tensor([[-3.4402,-2.4402,-1.4402,-0.4402],[-3.4402,-2.4402,-1.4402,-0.4402],[-3.4402,-2.4402,-1.4402,-0.4402]])

还有真实标签:

tensor([1, 2, 3])

y-target中第一个数字,代表第一个样本的真实标签,第一个数字是1,代表第一个样本的真实标签为1。所以取出矩阵A的第一行第二个数字-2.4402,并加上一个负号。之后也是这样,依次取出-1.4402、-0.4402。

所以有了最终结果:

crossentropyloss_output:
 tensor([2.4402, 1.4402, 0.4402])

reduction这个参数着重提一下,它一般有none、sum、mean等几个选项,none就是没有别的附加操作,sum就是把这个几个损失加和,mean就是把这几个损失求平均。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值