补充:标签平滑的实现

一、torch.gather()函数与交叉熵损失的计算

1.1 torch.gather()函数

 这个函数较难懂。这里讲一下,因为后面讲标签平滑的时候会用到。

 结果的形状与index的形状一样。比如上例中,index是一个2*2的张量,那么最后的结果就是2*2的张量。

首先令out[i][j] = input[i][j],然后因为dim是1,所以要把j换掉,换成index[i][j]。这样就得到最后的结果了。

out[0][0] = input[0][0], 把0换成index[0][0] = 0。所以out[0][0] = input[0][0] = 1;

out[0][1] = input[0][1], 把1换成index[0][1] = 0。所以out[0][1] = input[0][0] = 1;

out[1][0] = input[1][0], 把0换成index[1][0] = 1。所以out[1][0] = input[1][1] = 4;

out[1][1] = input[1][1], 把1换成index[1][1] = 0。所以out[1][1] = input[1][0] = 3;

因此最后的结果为:[[1,1],[4,3]]

1.2 利用torch.gather()函数计算交叉熵损失

交叉熵损失就不多少了,这里直接举例子。

例1:计算交叉熵损失函数

模型的输出和标签如下:

output = torch.tensor([[4.0, 5.0, 10.0], [1.0, 5.0, 4.0], [1.0, 15.0, 4.0]])  #共3个样本,3分类任务

label = torch.tensor([2, 1, 1], dtype=torch.int64)

# @file name  : test4.py
# @brief      : 求交叉熵损失
# @author     : liupc
# @date       : 2021/8/18

import torch
import torch.nn as nn

output = torch.tensor([[4.0, 5.0, 10.0], [1.0, 5.0, 4.0], [1.0, 15.0, 4.0]])
label = torch.tensor([2, 1, 1], dtype=torch.int64)

#方法一:使用交叉熵损失函数求
criterion1 = nn.CrossEntropyLoss()
loss1 = criterion1(output, label)
print("使用交叉熵损失函数计算的交叉熵损失为:{}".format(loss1))


#方法二:借助NLLLoss()函数
log_probs = torch.nn.functional.log_softmax(output, dim=-1)  # 实现 log(p_i)
print(log_probs)
criterion2 = nn.NLLLoss()
loss2 = criterion2(log_probs, label)
print("借助NLLLoss()函数计算的交叉熵损失为:{}".format(loss2))


#方法三:借助torch.gather()函数来求:
#log_probs = torch.nn.functional.log_softmax(output, dim=-1)  # 实现  log(p_i)
#下面这三句,其实就是实现的NLLLoss()的功能:把对应位置元素拿出来,求和,取平均
loss3 = -log_probs.gather(dim=-1, index=label.unsqueeze(1))
loss3 = loss3.squeeze(1)                          #H_pq就是每张图片的交叉熵损失,再求平均就是整体的交叉熵损失
loss3 = loss3.mean()
print("借助torch.gather()函数求的交叉熵损失为:{}".format(loss3))


结果:

前两种方法在《整理:熵、KL散度、交叉熵、nn.CrossEntropyLoss()、nn.LogSoftmax()、nn.NLLLoss()》都讲了。这里提供了第三种方法。

是借助torch.gather()实现的。这里就是使用gather()函数来代替了NLLLoss()的功能。讲解如下:

主要解释in[4]这一句。index是3*1的向量,所以res也是3*1的向量。dim=1,说明第1维被替换。

res[0][0] = input[0][?] = input[0][index[0][0]] = input[0][2] = 9.1745e-03

res[1][0] = input[1][?] = input[0][index[1][0]] = input[1][1] = 3.2656e-01

res[2][0] = input[2][?] = input[0][index[2][0]] = input[2][

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值