一、torch.gather()函数与交叉熵损失的计算
1.1 torch.gather()函数
这个函数较难懂。这里讲一下,因为后面讲标签平滑的时候会用到。
结果的形状与index的形状一样。比如上例中,index是一个2*2的张量,那么最后的结果就是2*2的张量。
首先令out[i][j] = input[i][j],然后因为dim是1,所以要把j换掉,换成index[i][j]。这样就得到最后的结果了。
out[0][0] = input[0][0], 把0换成index[0][0] = 0。所以out[0][0] = input[0][0] = 1;
out[0][1] = input[0][1], 把1换成index[0][1] = 0。所以out[0][1] = input[0][0] = 1;
out[1][0] = input[1][0], 把0换成index[1][0] = 1。所以out[1][0] = input[1][1] = 4;
out[1][1] = input[1][1], 把1换成index[1][1] = 0。所以out[1][1] = input[1][0] = 3;
因此最后的结果为:[[1,1],[4,3]]
1.2 利用torch.gather()函数计算交叉熵损失
交叉熵损失就不多少了,这里直接举例子。
例1:计算交叉熵损失函数
模型的输出和标签如下:
output = torch.tensor([[4.0, 5.0, 10.0], [1.0, 5.0, 4.0], [1.0, 15.0, 4.0]]) #共3个样本,3分类任务
label = torch.tensor([2, 1, 1], dtype=torch.int64)
# @file name : test4.py
# @brief : 求交叉熵损失
# @author : liupc
# @date : 2021/8/18
import torch
import torch.nn as nn
output = torch.tensor([[4.0, 5.0, 10.0], [1.0, 5.0, 4.0], [1.0, 15.0, 4.0]])
label = torch.tensor([2, 1, 1], dtype=torch.int64)
#方法一:使用交叉熵损失函数求
criterion1 = nn.CrossEntropyLoss()
loss1 = criterion1(output, label)
print("使用交叉熵损失函数计算的交叉熵损失为:{}".format(loss1))
#方法二:借助NLLLoss()函数
log_probs = torch.nn.functional.log_softmax(output, dim=-1) # 实现 log(p_i)
print(log_probs)
criterion2 = nn.NLLLoss()
loss2 = criterion2(log_probs, label)
print("借助NLLLoss()函数计算的交叉熵损失为:{}".format(loss2))
#方法三:借助torch.gather()函数来求:
#log_probs = torch.nn.functional.log_softmax(output, dim=-1) # 实现 log(p_i)
#下面这三句,其实就是实现的NLLLoss()的功能:把对应位置元素拿出来,求和,取平均
loss3 = -log_probs.gather(dim=-1, index=label.unsqueeze(1))
loss3 = loss3.squeeze(1) #H_pq就是每张图片的交叉熵损失,再求平均就是整体的交叉熵损失
loss3 = loss3.mean()
print("借助torch.gather()函数求的交叉熵损失为:{}".format(loss3))
结果:
前两种方法在《整理:熵、KL散度、交叉熵、nn.CrossEntropyLoss()、nn.LogSoftmax()、nn.NLLLoss()》都讲了。这里提供了第三种方法。
是借助torch.gather()实现的。这里就是使用gather()函数来代替了NLLLoss()的功能。讲解如下:
主要解释in[4]这一句。index是3*1的向量,所以res也是3*1的向量。dim=1,说明第1维被替换。
res[0][0] = input[0][?] = input[0][index[0][0]] = input[0][2] = 9.1745e-03
res[1][0] = input[1][?] = input[0][index[1][0]] = input[1][1] = 3.2656e-01
res[2][0] = input[2][?] = input[0][index[2][0]] = input[2][