位置编码计算方式
公式解析
PE(pos,2i) = sin(pos / (10000^(2i/d))):这个公式计算的是位置编码向量中偶数维度(2i)的值。其中,pos是单词所在的位置,比如第一个单词的位置是0,第二个单词的位置是1,以此类推。2i表示位置编码向量中的偶数维度,例如当i=0时,2i就是第0维;当i=1时,2i就是第2维。d是位置编码的总维度,也就是位置编码向量的长度。
PE(pos,2i + 1) = cos(pos / (10000^(2i/d))):这个公式计算的是位置编码向量中奇数维度(2i+1)的值。计算方法和偶数维度类似,只是使用了余弦函数。
举例说明
假设位置编码的总维度d=4,我们要计算第3个单词(pos=2)的位置编码。
对于第0维(2i,i=0):PE(2,0) = sin(2 / (10000^(0/4))) = sin(2 / 1) = sin(2)
对于第1维(2i+1,i=0):PE(2,1) = cos(2 / (10000^(0/4))) = cos(2 / 1) = cos(2)
对于第2维(2i,i=1):PE(2,2) = sin(2 / (10000^(2/4))) = sin(2 / (10000^0.5)) = sin(2 / 100)
对于第3维(2i+1,i=1):PE(2,3) = cos(2 / (10000^(2/4))) = cos(2 / (10000^0.5)) = cos(2 / 100)
所以,第3个单词的位置编码向量就是[sin(2), cos(2), sin(2/100), cos(2/100)]。
计算位置编码的好处
保持语义信息
正余弦函数的值域是[-1, +1]。当我们将计算得到的位置编码与原词嵌入(也就是单词本身的向量表示)相加时,由于位置编码的值不会过大或过小,所以不会使相加后的结果偏离原词嵌入太远。这就保证了在加入位置信息的同时,不会破坏原有单词的语义信息。例如,原词嵌入是[0.5, 0.3, -0.2, 0.1],位置编码是[0.2, -0.1, 0.05, -0.03],相加后得到[0.7, 0.2, -0.15, 0.07],原词嵌入的主要特征仍然得以保留。
蕴含距离信息
依据三角函数的周期性和线性组合性质,第pos+k个位置编码可以表示为第pos个位置编码的线性组合。简单来说,就是通过位置编码可以推算出单词之间的相对距离。
例如,假设我们有位置编码PE(pos)和PE(pos+k),由于正余弦函数的周期性,PE(pos+k)的各个维度的值可以通过对PE(pos)的各个维度的值进行一定的线性变换(比如相位移动等)得到。这就意味着,当我们看到两个位置编码时,能够从中感知到这两个位置对应的单词之间的距离关系。这对于模型理解句子中单词的顺序和相对位置非常重要,比如在处理“我爱北京天安门”这样的句子时,模型可以通过位置编码知道“我”和“天安门”之间的距离,从而更好地理解句子的结构和语义。