先看论文中的图:
显然,AutoCorrelation 完全可以替换 self-attention 。输入都是 q,k, v ,输出是一个 V。
直接看代码:
def forward(self, queries, keys, values, attn_mask):
B, L, H, E = queries.shape
_, S, _, D = values.shape
if L > S:
zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
values = torch.cat([values, zeros], dim=1)
keys = torch.cat([keys, zeros], dim=1)
else:
values = values[:, :L, :, :]
keys = keys[:, :L, :, :]
# period-based dependencies
q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1) # 32, 4, 128, 49
k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1) # 32, 4, 128, 49
res = q_fft * torch.conj(k_fft) # 32, 4, 128, 49
corr = torch.fft.irfft(res, dim=-1) # 32, 4, 128, 96
一般情况下,queries, keys, values 的shape都是一样的,比如:[32, 96, 4, 128]。
当然,解码器的第二个AC,keys和 values 的shape 一样,比如:[32, 96, 4, 128]。queries 的为[32, 48+192, 4, 128]。通过上面 if 语句会把 keys和 values 的长度补0 到 queries 一样长。
本文以 queries, keys, values 的shape:[32, 96, 4, 128] 说明。
显然, fft 并没有可说的。主要关注的在 time_delay_agg_training 函数。
V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
def time_delay_agg_training(self, values, corr): # 32 4 128 96
"""
SpeedUp version of Autocorrelation (a batch-normalization style design)
This is for the training phase.
"""
# 把批次看成1 进行分析,一目了然
head = values.shape[1] # 4
channel = values.shape[2] # 128
length = values.shape[3] # 96
# find top k
top_k = int(self.factor * math.log(length)) # 1* ln(96) = 4
mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) # 32, 96 每次求mean 那个维度就会消失
print(torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)) # 4个值
index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1] # 4 假设得到:[3, 4, 5, 2] # 这里[1]后得到的是index ## 批次维度求均值得到 shape: 96 再取 top_k
weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1) # 32 4 把96个里面最大的4个值取出来
# update corr
tmp_corr = torch.softmax(weights, dim=-1) # 32 4 这4个值归一化
# aggregation
tmp_values = values # 32 4 128 96
delays_agg = torch.zeros_like(values).float() # 32 4 128 96 都是0
for i in range(top_k):
pattern = torch.roll(tmp_values, -int(index[i]), -1) # 32 4 128 96 从序列那个维度移动 # index 对应移动的步长
delays_agg = delays_agg + pattern * \
(tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) # pattern 乘的是 那个 R_q,k 后面部分表示对应元素相乘 后面的就是那个权重 unsqueeze repeat 是为了形状相同 可以相乘 # tmp_corr[:, i] 这里不考虑batch_size 时是一个数 把这个数repeat 96次
return delays_agg # 32 4 128 96
上图,假设top_k = 2, 序列长度 = 10。
该模型更新完毕。
更多参考:请参考前文论文阅读部分的参考文献。作者回答了为什么用 torch.roll 。使用 top_k的目的还是降低计算复杂度。