Implementing token embeddings + positional encoding in a Transformer
import torch
import math
from torch import nn


class EmbeddingWithPosition(nn.Module):
    """
    vocab_size:  vocabulary size
    emb_size:    embedding (word-vector) dimension
    seq_max_len: maximum sentence length (set by hand, e.g. GPT-2 uses 1024)
    """
    def __init__(self, vocab_size, emb_size, dropout=0.1, seq_max_len=5000):
        super().__init__()
        self.seq_emb = nn.Embedding(vocab_size, emb_size)

        # Precompute the sinusoidal positional-encoding table of shape (seq_max_len, emb_size).
        position_idx = torch.arange(0, seq_max_len, dtype=torch.float).unsqueeze(-1)
        position_emb_fill = position_idx * torch.exp(-torch.arange(0, emb_size, 2, dtype=torch.float) * math.log(10000.0) / emb_size)
        position_emb = torch.zeros(seq_max_len, emb_size)
        position_emb[:, 0::2] = torch.sin(position_emb_fill)  # even dimensions -> sin
        position_emb[:, 1::2] = torch.cos(position_emb_fill)  # odd dimensions  -> cos
        # Register as a buffer: it moves with the module (.to/.cuda) but is not a trainable parameter.
        self.register_buffer('pos_encoding', position_emb)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch_size, seq_len) token ids -> (batch_size, seq_len, emb_size)
        x = self.seq_emb(x)
        # Add only the first seq_len rows of the positional table, broadcast over the batch.
        x = x + self.pos_encoding[:x.size(1), :].unsqueeze(0)
        return self.dropout(x)
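A quick sanity check of the module (a minimal sketch; vocab_size, emb_size and the token shape below are arbitrary demo values, not from the original):

# Sketch: run the module on a small batch of random token ids (hypothetical sizes).
emb = EmbeddingWithPosition(vocab_size=1000, emb_size=512)
tokens = torch.randint(0, 1000, (2, 16))   # (batch_size=2, seq_len=16) token ids
out = emb(tokens)
print(out.shape)                           # torch.Size([2, 16, 512])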
A hand-written, easier-to-follow version:
import torch
def create_pe_absolute_sincos_embedding_gov(n_pos_vec, dim):
    assert dim % 2 == 0, "wrong dim"
    position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)
    # omega[i] = 1 / 10000^(2i/dim), the frequency of the i-th sin/cos pair
    omega = torch.arange(dim // 2, dtype=torch.float)
    omega /= dim / 2.
    omega = 1. / (10000 ** omega)
    # theta[pos, i] = pos * omega[i], shape (n_pos, dim//2)
    theta = n_pos_vec[:, None] @ omega[None, :]
    emb_sin = torch.sin(theta)
    emb_cos = torch.cos(theta)
    position_embedding[:, 0::2] = emb_sin
    position_embedding[:, 1::2] = emb_cos
    return position_embedding
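For reference, both functions fill the table with the standard sinusoidal positional encoding from "Attention Is All You Need", where pos is the position, i indexes the sin/cos pair and d is the embedding dimension:

PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)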
def create_pe_absolute_sincos_embedding(n_pos_vec, dim):
    """
    Absolute (sinusoidal) positional encoding.
    :param n_pos_vec: 1-D tensor of positions to encode
    :param dim: embedding dimension
    :return: positional-encoding table of shape (len(n_pos_vec), dim)
    """
    assert dim % 2 == 0, "dim must be even"
    position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)
    # Build the same frequencies step by step: omega[i] = 1 / 10000^(2i/dim)
    omega = torch.arange(dim // 2, dtype=torch.float)
    omega *= 2
    omega /= dim
    omega = torch.pow(10000, omega)
    omega = 1 / omega
    print("n_pos_vec shape:", n_pos_vec.unsqueeze(1).shape)
    print("omega shape:", omega.shape)
    # Broadcast (n_pos, 1) * (dim//2,) -> (n_pos, dim//2)
    position_embedding[:, 0::2] = torch.sin(n_pos_vec.unsqueeze(1) * omega)
    position_embedding[:, 1::2] = torch.cos(n_pos_vec.unsqueeze(1) * omega)
    return position_embedding
if __name__ == "__main__":
    n_pos = 4
    dim = 8
    n_pos_vec = torch.arange(n_pos, dtype=torch.float)
    position_embedding = create_pe_absolute_sincos_embedding(n_pos_vec, dim)
    position_embedding_1 = create_pe_absolute_sincos_embedding_gov(n_pos_vec, dim)
    # The two implementations should produce the same table.
    print(torch.allclose(position_embedding, position_embedding_1))
    print("position embedding shape:", position_embedding.shape)
Reference version