回顾transformer中绝对位置编码(absolute position embedding)
在transformer的实现中,所有的input tokens是无序的,是没法像RNN的方法一样学到token之间的位置顺序关系,但是自然语言一定是有语序在里面的,所以原始的transformer的代码实现里就提供了一种非常简单的位置编码,输入的是一个和max_sequence_length一样长的固定index序列,比如在max_sequence_length=64的情况下,输入的位置信息序列就是[0,1,2,3…62,63],然后通过一个embedding层让网络自己去学习位置的编码和表征。因为不管输入的文本是什么,位置编码永远是一个定值的序列(在不同的网络中仅长度会变化),所以这就是一种绝对位置编码方式。我个人对这种位置编码方式是否能提供有用的位置信息是存疑的,因为当网络训练好了以后,所有的input embedding相当于都加上一个定值,然而不同的input里面的位置和语义信息都是不同的,我认为这种加定值的方式并不能学到一些位置带来的语义上的差别,但毕竟增加了一定的参数量,可能聊胜于无吧。
贴一下transformer中绝对位置编码的实现,在modeling_bert.py中,实现也很简单:
class BertEmbeddings(nn.Module):
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) #初始化定义position_embedding的embedding层,其实底层就是一个fc
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length] #生成定值的序列
if position_ids.dtype is not torch.long:
position_ids = position_ids.to(torch.long)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
embeddings = inputs_embeds
embeddings += token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids) #输入的固定序列经过一个fc得到位置编码,直接point-wise加回到token embedding上
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
相对位置编码方法详解和公式推导
相对位置编码的方法主要是出自于Google的一篇paper《Self-Attention with Relative Position Representations》
先说一下transformer中的self-attention机制:
每一层transformer由h个attention heads组成,其中每个attention head的输入为n个元素组成的token序列 x = ( x 1 , . . . , x n ) x = (x_1,...,x_n) x=(x1,...