Transformer位置编码层的实现

一、前言

在这篇文章我们将实现位置编码层,跟传统的 LSTM序列模型不同,在 Transformer 编码结构中,均是基于全连接层实现,Linear层没有捕捉位置信息的能力,因此纯粹的 Attention 模块是无法捕捉输入顺序的,因此需要在 Embedding 层后加入位置编码器,将词汇位置不同,可能会产生不同语义信息,加入到词嵌入张量中。Transformer 采用的是正余弦的绝对位置编码,这种编码方式可以保证,不同位置在所有维度上不会被编码到完全一样的值,从而使每个位置都获得独一无二的编码。通俗理解,就是将位置信息编码成向量后,每个位置对应的向量都是不同的。

二、代码实现

方法1:

该方法主要是复现上述公式,pos是词的位置,dmodel是词向量化后的维度。

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000,dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.pe = torch.zeros(max_len, d_model)
        # max_len=5000 通常一句话的长度值不会超过5000
        for pos in range(max_len):
            for j in range(d_model):
                angle = pos / math.pow(10000, (j // 2) * 2 / d_model)
                if j % 2 == 0:
                    self.pe[pos][j] = math.sin(angle)
                else:
                    self.pe[pos][j] = math.cos(angle)

    def forward(self, x):
        return self.dropout(x + self.pe[:x.size(1)])

方法2: 

该方法是基于原始公式的变形,因为原始公式的实现需要两层循环,时间复杂度为n的平方,公式变形后可以降低时间复杂度。

class PositionalEncoding(nn.Module):
    def __init__(self,d_model,dropout=0.1,max_len=5000):
        super().__init__()
        self.dropout=nn.Dropout(dropout)
        pe = torch.zeros(max_len,d_model)
        # 位置和除数
        position = torch.arange(0,max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0,d_model,2) * -math.log(10000) / d_model)
        # 修改pe矩阵的值
        pe[:,0::2] = torch.sin(position*div_term)
        pe[:,1::2] = torch.cos(position*div_term)
        # 扩展 batch 维度
        pe = pe.unsqueeze(0)
        # 存储为不需要计算梯度的参数
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)
        return self.dropout(x)

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

木珊数据挖掘

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值