正确理解PE

文章探讨了位置和顺序在Transformer模型中的重要性,介绍了如何通过正弦、余弦函数的线性组合实现位置编码以捕捉依赖关系。对比了固定编码(encoding)与可学习位置(embedding),并提供了Pytorch中的PositionalEncoding实现代码示例。
摘要由CSDN通过智能技术生成

1.为什么?

位置和顺序对于一些任务十分重要,例如理解一个句子、一段视频。位置和顺序定义了句子的语法、视频的构成,它们是句子和视频语义的一部分。Transformer使用MHSA(Multi-Head Self-Attention),从而避免使用了RNN的递归方法,加快了训练时间,同时,它可以捕获句子中的长依赖关系,能够应对更长的输入。简而言之,Trans架构需要对输入的Tokens添加位置编码,以辅助模型学习不同patch之间相对位置关系,比如词序、图像分割的patch。

2.怎么做的?

首先paper中是使用正弦、余弦函数线性组合来对每一个维度进行位置编码,这里叫做position encoding。这样做的好处是:在满足需求的条件下节省内存空间!!对于有限的显存来讲,内存空间太太太重要了!N卡的显存大小与金钱正相关!!

其次,这种sin-cos组合可以更容易得到位置关系,证明请看:知乎 - 安全中心

再次,原文中采用add,而不是concat,将patch embedding和position encoding进行连接,然后送入transformer encoder进行训练。具体add和concat哪个效果更好,其实差不多!可能的猜想是,embedding之后的向量属于高维向量,position encoding后的与之同维,也是高维,二者在高维空间内分别占据部分子空间,并相互正交。所以add or concat不影响结果,add还可以减少参数量。详细过程期待大佬可以理论证明一下~

最后,代码实现:参考Transformer实现以及Pytorch源码解读(三)-位置编码Position Encoding——史上最容易理解_pytorch中的position_encoding是什么意思-CSDN博客

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=512):
        #d_model是每个词embedding后的维度
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term2 = torch.pow(torch.tensor(10000.0),torch.arange(0, d_model, 2).float()/d_model)
        div_term1 = torch.pow(torch.tensor(10000.0),torch.arange(1, d_model, 2).float()/d_model)
        #高级切片方式,即从0开始,两个步长取一个。即奇数和偶数位置赋值不一样。直观来看就是每一句话的
        pe[:, 0::2] = torch.sin(position * div_term2)
        pe[:, 1::2] = torch.cos(position * div_term1)
        #这里是为了与x的维度保持一致,释放了一个维度
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return x

区分:

embedding 和encoding:encoding是固定编码,直接公式计算得到,如transformer; 而embedding是可学习位置,是参数,如Bert.

  • 17
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值