在看trm源码的时候关于transformer position encoding的部分不能理解,记录一下以免以后要用到。(比较浅显,基本根据论文公式反推的)
# For positional encoding
num_timescales = self.hidden_size // 2##一半余弦,一半正弦
max_timescale = 10000.0
min_timescale = 1.0##max_timescale min_timescale是时间尺度的上下界
##以上:计算时间尺度
log_timescale_increment = (
math.log(float(max_timescale) / float(min_timescale)) /
max(num_timescales - 1, 1))##感觉是(max_timescale-min_timescale)取对数。
##在对数空间中相邻时间尺度之间的增量
##计算时间尺度的增量
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float32) *
-log_timescale_increment)
##计算时间尺度的值 ##inv_timescales是时间尺度的倒数
self.register_buffer('inv_timescales', inv_timescales)
def get_position_encoding(self, x):
max_length = x.size()[1]
position = torch.arange(max_length, dtype=torch.float32,
device=x.device)
##position=[0.,1.,2.,……,max_length-1.]
scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)],
dim=1)
signal = F.pad(signal, (0, 0, 0, self.hidden_size % 2))
signal = signal.view(1, max_length, self.hidden_size)
return signal
首先看下面这张图,便于我们直观地理解pe的构成:
可以自己输出一下代码,发现torch.sin(scaled_time), torch.cos(scaled_time)的维度都是N*D/2(N是序列的最长长度,D是hidden_size),所以signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)],dim=1)后维度就是N*D,也就是对每个pos的token的每一个隐藏层维度都进行一个位置编码(其实位置编码要和词嵌入向量做加法来加入位置信息,我们知道词嵌入向量的维度是B*N*D,那么位置编码1*N*D才能做加法)
然后,我不太理解为什么下面这段代码要取个指数函数。
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float32) *
-log_timescale_increment)
其实有下面这段推导:
下面介绍一下各参数含义:
num_timescales=256 ##也就是d/2
max_timescale=10000.0 ##也就是公式中那个10000
log_timescale_increment ##公式中的(2/d)*ln10000
max(num_timescales - 1, 1) ##防止分母为0
inv_timescales ##exp((2*i/d)*ln10000)
def get_position_encoding(self, x):
max_length = x.size()[1] ##x传入的参数是input B*N*D,所以这里取的是N
position = torch.arange(max_length, dtype=torch.float32,
device=x.device)
##position=tensor([0.,1.,2.,……,max_length-1.]),维度是N
scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0)
##position.unsqueeze(1)后维度是N*1,inv_timescales.unsqueeze(0)后维度是1*D/2
##所以scaled_time维度是N*D/2
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)],
dim=1)
signal = F.pad(signal, (0, 0, 0, self.hidden_size % 2))
##signal维度是N*D
signal = signal.view(1, max_length, self.hidden_size)
##把signal拉成(1*N*D)
return signal
另外举两个例子来理解一下torch.nn.functional.pad():
直接看图
对于一个二维的矩阵 pad中的tuple参数(a,b,c,d)代表
(a,b)表示对矩阵的倒数第一个维度做padding操作,在原tensor的左侧pad a个,右侧pad b个。
(c,d)表示对矩阵的倒数第二个维度做padding操作,在原tensor的左侧pad c个,右侧pad d个。