Transformer详解(2)-位置编码

Transformer结构不再使用基于循环的方式建模文本输入,序列中不再有任何信息能够提示模型单词之间的相对位置关系。因此,Transformer自身无法感知位置信息,需要输入层的额外位置信息。序列中每一个单词所在的位置都对应一个向量。这一向量会与单词表示对应相加并送入后续模块中做进一步处理。在训练过程中,模型会自动的学习到如何利用这部分位置信息。

位置编码公式

偶数位置用sin,奇数位置用cos. d_model 表示token的维度;pos表示token在序列中的位置;i表示每个token编码的第i个位置,属于[0,d_model)。

torch实现

import math
import torch
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as plt


class PositionalEncoder(nn.Module):
    def __init__(self, max_seq_len=50, d_model=128):
        super().__init__()
        self.d_model = d_model  # d_model 表示token的维度

        pe = torch.zeros(max_seq_len, d_model)  # max_seq_len * d_model 的二维张量   例如: 50*128
        for pos in range(max_seq_len):  # 重新初始化
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** (i / d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** (i / d_model)))

        pe = pe.unsqueeze(0)  # 1*50*128
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x * math.sqrt(self.d_model)

        seq_len = x.size(1)

        x = x + Variable(self.pe[:, :seq_len], requires_grad=False).cuda()
        return x

if __name__ == '__main__':
    positional_encoder = PositionalEncoder(50, 128)
    plt.pcolormesh(positional_encoder.pe.numpy()[0], cmap='RdBu')
    plt.xlabel('Depth')  # 50
    plt.xlim((0, 128))
    plt.ylabel('Position')  # 128
    plt.colorbar()
    plt.show() 

位置编码可视化

位置编码可视化

  • 14
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值