【Pytorch学习笔记】4.细讲Pytorch的gather函数是什么——从Softmax回归中交叉熵损失函数定义的角度讲述

本文详细解析了torch.gather函数在Softmax回归中的应用,通过实际例子演示如何根据索引矩阵从数据矩阵中获取特定值,进而构建交叉熵损失函数。重点在于从定义角度解释其在计算概率分布差异中的作用,并阐述了其内存效率和矩阵运算优势。
摘要由CSDN通过智能技术生成

本文接着慢慢磨pytorch基础。本来是想记录一下心得,结果码着码着又讲成了story。

gather函数:原始数据矩阵 根据索引矩阵 取到对应值矩阵

我们在学习Softmax回归从零实现的时候,需要定义一个交叉熵损失函数。我们会使用torch.gather函数的方法取原始数据矩阵中对应位置的值,接着再取log等处理。

可以先看看的例子

可以先囫囵吞枣地看一下gather函数的例子:

# 变量y_hat是2个样本在3个类别的预测概率,变量y是这2个样本的标签类别。
import torch
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
# y是一个索引矩阵,包含了每个样本的正确类别的位置。
# 比如第1个样本是第1类,第2个样本是第3类。0即第1个,2即第2个。
y = torch.LongTensor([0, 2])
y_hat.gather(1, y.view(-1, 1))

# 输出:
tensor([[0.1000],
        [0.5000]])

可以看到y是一个索引矩阵,包含了每个样本的正确类别的位置。比如第1个样本是第1类,第2个样本是第3类(0是第1个,2是第3个)。
原始数据矩阵y_hat使用gather传入y后,成功取到了两个样本的预测概率分布 各自对应正确类别下的 概率。
所以这里已经可以理解:gather函数是 原始数据矩阵 根据 索引矩阵 取到 对应值矩阵 的一个过程。

这样我们就可以定义交叉熵损失函数:

def cross_entropy(y_hat, y):
    return - torch.log(y_hat.gather(1, y.view(-1, 1)))

官方文档解释

官方文档中对gather函数的描述非常简洁,即按设定的维度方向,按该维度上的索引值取值。
在这里插入图片描述
给的例子也很直白。但是直观上看这是一个逻辑上非常绕的函数,按维度方向根据索引取值也是很难想象的。

我的思考是为什么我们要用这么别扭的方法取值呢?定义交叉熵损失函数为啥要用gather呢?

从交叉熵损失函数定义的角度理解gather函数的使用

  1. 在Softmax回归中,我们知道要使用交叉熵损失函数来计算两个概率分布之间的差异。
    还是一开始的例子,我们有 2个样本3类别的预测概率分布y_hat 和实际的概率分布y,如图:
    在这里插入图片描述

  2. 根据交叉熵损失函数:
    在这里插入图片描述
    对于softmax回归来说,真实的概率分布只有在正确的类别值为1,所以损失函数最后可以化简为:
    − log ⁡ y ^ p ( i ) -\log \hat{y}_{p}^{(i)} logy^p(i)。即在预测概率分布中取到第p个值,而p就是索引值,对应正确的类别。

  3. 对于一个batch的数据来说,每一行是一个样本的各类别概率分布,每一行都会有一个正确分类,位于第 p ( i ) p^{(i)} p(i)列。这个第 p ( i ) p^{(i)} p(i)列也是真实分布y每一行中1的位置。所以我们只要知道每个样本正确的类别所在列的标号就行了。
    即对于实际概率分布y矩阵:
    在这里插入图片描述我们只需要把y矩阵转变为索引矩阵
    故对 y_hat 的每个样本取索引值就得到了对应类别的预测概率。之后就可以进行取log等操作来定义损失函数了。

  4. 这样,根据一个索引矩阵对一个原始数据矩阵的行(或列)取索引对应的值,就是gather函数的具体做的事了。而交叉熵损失函数的定义正好符合了这样的数据特征,所以就能正好使用gather函数了。
    (这里的索引矩阵y在实际例子中就是一个batch数据X的对应label y。)

这样做有什么好处呢?个人觉得:
6. 将真实概率矩阵(抑或叫One-Hot编码矩阵)缩减为索引矩阵可以大大减少内存开销。
7. 索引值矩阵的维数没有变化,依然保留了矩阵并行运算效率高的优点。

torch.gather()

讲了半天交叉熵损失函数,最后讲讲gather函数的使用。
对于初学者来说,最重要的参数就是dim。在开头定义交叉熵损失函数的时候y_hat.gather(1, y.view(-1, 1)),参数dim=1,以行为方向进行索引。
直接用二维Tensor举例:

import torch
src = torch.arange(1, 16).reshape(5, 3)
"""
src:
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12],
        [13, 14, 15]])
"""

# 定义两个索引矩阵
index1 = torch.tensor([[0, 1, 2], [2, 3, 4], [0, 2, 4]])
index2 = torch.tensor([[1, 2, 0, 2, 1], [1, 2, 1, 0, 0]]).t()
"""
index1:
tensor([[0, 1, 2],
        [2, 3, 4],
        [0, 2, 4]])
index2:
tensor([[1, 1],
        [2, 2],
        [0, 1],
        [2, 0],
        [1, 0]])
"""

# axis=0时,
output1 = src.gather(dim=0, index=index1)
print(output1)
"""
输出:
tensor([[ 1,  5,  9],
        [ 7, 11, 15],
        [ 1,  8, 15]])
"""

# axis=1
output2 = src.gather(dim=1, index=index2)
print(output2)
"""
输出:
tensor([[ 2,  2],
        [ 6,  6],
        [ 7,  8],
        [12, 10],
        [14, 13]])
"""

当一个数据矩阵用一个索引矩阵在某个维度方向上取值,那个方向上的所有值 按 索引矩阵中那个方向的索引值 来取值。

我们直接看下图就更加清晰明了(原图中Dim以1、2来说明,在pytorch中即为0、1):
指定了一个方向以后,index以指定方向一刀切下去,对这个方向的数字取值
指定了一个方向以后,index以指定方向一刀切下去,对这个方向的数字取值。
index和src的维度必须要一致,且除去切的那一维,其余维度shape也一致(扩展到三维也是这样)。

  • 7
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
以下是基于PyTorch的ArcFace人脸识别系统包含ArcFace函数的models.py文件的代码: ```python import torch import torch.nn as nn import torch.nn.functional as F class ArcFace(nn.Module): def __init__(self, embedding_size, num_classes, margin=0.5, scale=64): super().__init__() self.embedding_size = embedding_size self.num_classes = num_classes self.margin = margin self.scale = scale self.weight = nn.Parameter(torch.FloatTensor(num_classes, embedding_size)) nn.init.xavier_uniform_(self.weight) def forward(self, embeddings, labels): # normalize input embeddings embeddings = F.normalize(embeddings) # normalize weights weights = F.normalize(self.weight) # gather the correct weight for each label cosine = F.linear(embeddings, weights) logits = self.scale*cosine # add margin to the correct logit mask = F.one_hot(labels, self.num_classes) logits[mask.bool()] -= self.margin # softmax cross-entropy loss loss = F.cross_entropy(logits, labels) return loss ``` 该代码实现了一个名为ArcFace的类,它是一个PyTorch模块,可以用于训练人脸识别模型。该类的构造函数接受几个参数:embedding_size表示每个人脸图像嵌入的向量大小,num_classes表示人脸库的人数,margin表示ArcFace的余弦相似度边界,scale表示每次前向传递时对余弦相似度的缩放因子。 该类的forward()方法接受两个参数:embeddings表示一个大小为(batch_size, embedding_size)的张量,其包含了一批人脸图像的嵌入向量;labels表示一个大小为(batch_size,)的张量,其包含了每个嵌入向量对应的人脸ID。该方法首先将嵌入向量和权重向量归一化,然后使用余弦相似度计算输入向量和权重向量之间的相似度得分。然后,对于每个嵌入向量,它的相似度得分被缩放(scale)和减去一个边界(margin),以获得最终的logit。最后,使用softmax交叉熵损失函数计算损失。 该模型的训练过程通常是使用随机梯度下降(SGD)优化器来最小化损失函数。在每个训练步骤,模型首先将输入图像传递到卷积神经网络,然后将得到的嵌入向量传递给ArcFace模块进行训练。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值