大模型中的三角位置编码实现

Transformer中嵌入表示 + 位置编码的实现

import torch
import math
from torch import nn



# 词嵌入位置编码实现
class EmbeddingWithPosition(nn.Module):
    """
    vocab_size:词表大小
    emb_size: 词向量维度
    seq_max_len: 句子最大长度 (人为设定,例如GPT2的最大长度是1024) 
    """


    def __init__(self, vocab_size, emb_size, dropout=0.1, seq_max_len=5000):

        self.seq_emb = nn.Embedding(vocab_size, emb_size) # 序列中每个token的embedding向量表示
        #  位置编码实现 (硬编码方式)
        position_idx = torch.arange(0, seq_max_len, dtype=torch.float).unsqueeze(-1)
        position_emb_fill = position_idx * torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000.0) / emb_size) # 三角位置编码实现

        position_emb = torch.zeros(seq_max_len, emb_size) # 位置编码 emb_size是嵌入维度大小
        position_emb[:, 0::2] = torch.sin(position_emb_fill)
        position_emb[:, 1::2] = torch.cos(position_emb_fill)

        self.register_buffer('pos_encoding', position_emb) # 固定参数,不需要train

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):

        x = self.seq_emb(x) # 嵌入层表示 (batch_size, seq_len, emb_size)
        # x = x + self.pos_encoding.unsqueeze(0)[:,:x.size()[1],:] # 添加位置编码
        x += self.pos_encoding.unsqueeze(0)
        return self.dropout(x)
    

自己动手实现易懂版本:

assert 10 % 2 == 0,  "wrong assert"
# 如果前面判断正确的话,则不会引发异常;否则,则会引发异常

import torch

import torch
def creat_pe_absolute_sincos_embedding_gov(n_pos_vec, dim):
  assert dim % 2 == 0, "wrong dim"
  position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)

  omega = torch.arange(dim//2, dtype=torch.float)
  omega /= dim/2.
  omega = 1./(10000**omega)

  sita = n_pos_vec[:,None] @ omega[None,:]
  emb_sin = torch.sin(sita)
  emb_cos = torch.cos(sita)

  position_embedding[:,0::2] = emb_sin
  position_embedding[:,1::2] = emb_cos

  return position_embedding



def create_pe_absulute_sincos_embedding(n_pos_vec, dim):
    """
    绝对位置编码
    :param n_pos_vec: 位置编码的长度向量
    :param dim: 词向量的维度
    :return: 位置编码
    """
    assert dim % 2 == 0, "dim must be even"
    position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float) # 三角函数位置编码

    omega = torch.arange(dim // 2, dtype=torch.float) # 0 ~ i, max_i: dim // 2
    omega *= 2
    omega /= dim 

    omega = torch.pow(10000, omega) # 10000^(2i/dim)
    omega = 1 / omega
    omega = omega

    print("n_pos_vec shape:",n_pos_vec.unsqueeze(1).shape)
    print("omega shape:", omega.shape).squeeze

    position_embedding[:, 0::2] = torch.sin(n_pos_vec.unsqueeze(1) * omega) # 偶数位置
    position_embedding[:, 1::2] = torch.cos(n_pos_vec.unsqueeze(1) * omega) # 奇数位置

    return position_embedding


if __name__ == "__main__":
    n_pos = 4
    dim = 8
    n_pos_vec = torch.arange(n_pos, dtype=torch.float)
    position_embeddding = create_pe_absulute_sincos_embedding(n_pos_vec, dim)
    position_embeddding_1 = creat_pe_absolute_sincos_embedding_gov(n_pos_vec, dim)
    print(position_embeddding == position_embeddding_1)
    print("position embedding shape:", position_embeddding.shape)

参考版本

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值