Transformer中的位置编码(PE,position encoding)
参考链接
https://blog.csdn.net/Flying_sfeng/article/details/100996524
https://blog.csdn.net/u012526436/article/details/86295971
transformer模型是当前大红大热的语言模型,今天要讲解的是transformer中的positional encoding(位置编码)。
我们知道,transformer模型的attention机制并没有包含位置信息,即一句话中词语在不同的位置时在transformer中是没有区别的,这当然是不符合实际的。
为了处理这个问题,transformer给encoder层和decoder层的输入添加了一个额外的向量Positional Encoding,维度和embedding的维度一样,这个向量采用了一种很独特的方法来让模型学习到这个值,这个向量能决定当前词的位置,或者说在一个句子中不同的词之间的距离。这个位置向量的具体计算方法有很多种,论文中的计算方法如下
其中,PE为二维矩阵,大小跟输入embedding的维度一样,行表示词语,列表示词向量;pos 表示词语在句子中的位置;dmodel表示词向量的维度;i表示词向量的位置。因此,上述公式表示在每个词语的词向量的偶数位置添加sin变量,奇数位置添加cos变量,以此来填满整个PE矩阵,然后加到input embedding中去,这样便完成位置编码的引入了。
具体实现的代码如下:
class PositionalEncoding(nn.Module):
"Implement the PE function."
def __init__(self, d_model, dropout, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
# Compute the positional encodings once in log space.
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) *
-(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + Variable(self.pe[:, :x.size(1)],
requires_grad=False)
return self.dropout(x)
其中,div_term是上述公式经过简单的数学变换得到的,具体如下:
我们来解读这个代码:
首先定义一个PE矩阵并零初始化,shape是(max_len,d_model)。后面假设max_len=5,d_model=4。
max_len=5
d_model=4
pe=torch.zeros(max_len,d_model)
然后我们定义一个position矩阵,shape是(max_len,1)
position=torch,arange(0.0,max_len).unsqueeze(1)
接下来定义div_term,这是一个一维向量,长度是dmodel/2
div_term = torch.exp(torch.arange(0, d_model, 2) *
-(math.log(10000.0) / d_model))
接下来稍微解释一下pe[ : , 0::2 ],这个表示所有行取奇数列。同理pe[ : , 0::2 ]表示取所有行偶数列。
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
这里就对每一行的每一维赋值,即做位置编码。
执行过程中可能会报错
RuntimeError: “exp” not implemented for ‘torch.LongTensor’
估计原因是pytorch的版本问题,解决方法为将torch.arange(0, d_model, 2)中的整型改为浮点型,即torch.arange(0.0, d_model, 2)