Transformer位置编码代码讲解

在看trm源码的时候关于transformer position encoding的部分不能理解,记录一下以免以后要用到。(比较浅显,基本根据论文公式反推的)

# For positional encoding
        num_timescales = self.hidden_size // 2##一半余弦,一半正弦
        max_timescale = 10000.0
        min_timescale = 1.0##max_timescale min_timescale是时间尺度的上下界
        ##以上:计算时间尺度

        log_timescale_increment = (
            math.log(float(max_timescale) / float(min_timescale)) /
            max(num_timescales - 1, 1))##感觉是(max_timescale-min_timescale)取对数。
        ##在对数空间中相邻时间尺度之间的增量
        ##计算时间尺度的增量

        inv_timescales = min_timescale * torch.exp(
            torch.arange(num_timescales, dtype=torch.float32) *
            -log_timescale_increment)
        ##计算时间尺度的值  ##inv_timescales是时间尺度的倒数

        self.register_buffer('inv_timescales', inv_timescales)

 def get_position_encoding(self, x):
        max_length = x.size()[1]
        position = torch.arange(max_length, dtype=torch.float32,
                                device=x.device)
        ##position=[0.,1.,2.,……,max_length-1.]
        scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0)
        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)],
                           dim=1)
        signal = F.pad(signal, (0, 0, 0, self.hidden_size % 2))
        signal = signal.view(1, max_length, self.hidden_size)
        return signal

首先看下面这张图,便于我们直观地理解pe的构成:

可以自己输出一下代码,发现torch.sin(scaled_time), torch.cos(scaled_time)的维度都是N*D/2(N是序列的最长长度,D是hidden_size),所以signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)],dim=1)后维度就是N*D,也就是对每个pos的token的每一个隐藏层维度都进行一个位置编码(其实位置编码要和词嵌入向量做加法来加入位置信息,我们知道词嵌入向量的维度是B*N*D,那么位置编码1*N*D才能做加法)

然后,我不太理解为什么下面这段代码要取个指数函数。

inv_timescales = min_timescale * torch.exp(
            torch.arange(num_timescales, dtype=torch.float32) *
            -log_timescale_increment)

其实有下面这段推导:

下面介绍一下各参数含义:

num_timescales=256 ##也就是d/2
max_timescale=10000.0 ##也就是公式中那个10000
log_timescale_increment ##公式中的(2/d)*ln10000
max(num_timescales - 1, 1) ##防止分母为0
inv_timescales ##exp((2*i/d)*ln10000)

    def get_position_encoding(self, x):
        max_length = x.size()[1] ##x传入的参数是input B*N*D,所以这里取的是N
        position = torch.arange(max_length, dtype=torch.float32,
                                device=x.device)
        ##position=tensor([0.,1.,2.,……,max_length-1.]),维度是N
        scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0)
        ##position.unsqueeze(1)后维度是N*1,inv_timescales.unsqueeze(0)后维度是1*D/2
        ##所以scaled_time维度是N*D/2
        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)],
                           dim=1)
        signal = F.pad(signal, (0, 0, 0, self.hidden_size % 2))
        ##signal维度是N*D
        signal = signal.view(1, max_length, self.hidden_size)
        ##把signal拉成(1*N*D)
        return signal

另外举两个例子来理解一下torch.nn.functional.pad():

直接看图

对于一个二维的矩阵 pad中的tuple参数(a,b,c,d)代表

(a,b)表示对矩阵的倒数第一个维度做padding操作,在原tensor的左侧pad a个,右侧pad b个。

(c,d)表示对矩阵的倒数第二个维度做padding操作,在原tensor的左侧pad c个,右侧pad d个。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Transformer代码讲解将包括以下几个模块的原理和代码解析: 1. 注意力机制(Attention Mechanism):在Transformer中,注意力机制用于计算输入序列中不同位置之间的相对重要性,并为输出序列的每个位置分配相应的权重。注意力机制的实现通常涉及到查询、键和值的计算以及计算注意力权重。 2. 多头注意力(Multi-head Attention):多头注意力是一种改进的注意力机制,在Transformer中被广泛使用。它通过将多个注意力头并行运行来捕捉不同的表示子空间,从而提高模型的表示能力。多头注意力的实现包括对注意力机制进行多次计算,并将结果进行拼接和线性变换。 3. 编码器(Encoder):编码器由多个相同的层堆叠而成,每个层都包含一个多头注意力子层和一个前馈神经网络子层。编码器用于对输入序列进行编码,捕获输入序列中的语义信息。 4. 解码器(Decoder):解码器也由多个相同的层堆叠而成,每个层包含一个多头注意力子层、一个编码器-解码器注意力子层和一个前馈神经网络子层。解码器用于生成输出序列,它利用编码器的输出和自身的历史输出来预测下一个输出。 5. 位置编码(Positional Encoding):由于Transformer没有像循环神经网络和卷积神经网络那样的显式位置信息,因此需要引入位置编码来捕捉输入序列中的位置信息。位置编码的实现通常使用正弦和余弦函数进行计算。 以上是Transformer代码的主要讲解内容。通过深入理解这些模块的原理和代码,可以更好地掌握Transformer模型的工作原理和实现方式。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值