前言:对于 transformer,informer 做出的改进第一点是使用概率注意力机制(ProbSparse attention )代替 self attention
ProbSparse attention 是 informer 的核心,本节主要讲代码原理
本代码中
- Q:查询张量,形状为 [B, H, L_Q, D]
- K:键张量,形状为 [B, H, L_K, E]
- sample_k:采样的键的数量
- n_top:Top-K查询的数量,计算为 factor * ln(L_Q)
- V:值张量,形状为 [B, H, L_V, D]
B表示批次大小,H表示头的数量,L_K表示键的长度,E 表示键的嵌入维度,L_Q 表示查询的长度,D 表示查询/值的嵌入维度,L_V表示值的长度
ProbSparse attention 分为两步
一、计算注意力分数
1. 通过在K上添加新维度拓展到 [B, H, L_Q, L_K, E] 形状
2. 根据 sample_k 对其长度维度上进行采样得到 [B, H, L_Q, sample_k, E] 形状,计算 Q_K_sample = Q * K.T
3. 根据 Q_K_sample(形状为 [B, H, L_Q, sample_k] )计算 M 以及 M_top,基于M_top得到新的 Q(形状为 [B, H, n_top, D] )
4. 使用新的 Q 计算注意力分数(形状为 [B, H, n_top, L_K] )
代码如下
class ProbAttention(nn.Module):
def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
super(ProbAttention, self).__init__()
"""
mask_flag:决定是否使用掩码,默认为True
factor:缩放因子,用于计算采样的数量 factor * ln(Q长度)
scale:决定是否缩放
attention_dropout:注意力机制中的dropout率
output_attention:决定是否输出注意力权重
"""
self.factor = factor