前言:对于 transformer,informer 做出的改进第一点是使用概率注意力机制(ProbSparse attention )代替 self attention
ProbSparse attention 是 informer 的核心,本节主要讲代码原理
本代码中
- Q:查询张量,形状为 [B, H, L_Q, D]
- K:键张量,形状为 [B, H, L_K, E]
- sample_k:采样的键的数量
- n_top:Top-K查询的数量,计算为 factor * ln(L_Q)
- V:值张量,形状为 [B, H, L_V, D]
B表示批次大小,H表示头的数量,L_K表示键的长度,E 表示键的嵌入维度,L_Q 表示查询的长度,D 表示查询/值的嵌入维度,L_V表示值的长度
ProbSparse attention 分为两步
一、计算注意力分数
1. 通过在K上添加新维度拓展到 [B, H, L_Q, L_K, E] 形状
2. 根据 sample_k 对其长度维度上进行采样得到 [B, H, L_Q, sample_k, E] 形状,计算 Q_K_sample = Q * K.T
3. 根据 Q_K_sample(形状为 [B, H, L_Q, sample_k] )计算 M 以及 M_top,基于M_top得到新的 Q(形状为 [B, H, n_top, D] )
4. 使用新的 Q 计算注意力分数(形状为 [B, H, n_top, L_K] )
代码如下
class ProbAttention(nn.Module):
def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
super(ProbAttention, self).__init__()
"""
mask_flag:决定是否使用掩码,默认为True
factor:缩放因子,用于计算采样的数量 factor * ln(Q长度)
scale:决定是否缩放
attention_dropout:注意力机制中的dropout率
output_attention:决定是否输出注意力权重
"""
self.factor = factor
self.scale = scale
self.mask_flag = mask_flag
self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout)
def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
# Q [B, H, L, D]
B, H, L_K, E = K.shape
_, _, L_Q, _ = Q.shape
# calculate the sampled Q_K
# K_expand.shape
# torch.Size([32, 8, 96, 96, 64])
# torch.Size([96, 25])
# K_sample.shape
# torch.Size([32, 8, 96, 25, 64])
# Q_K_sample.shape
# torch.Size([32, 8, 96, 25])
# 通过在 K 上添加一个新的维度并扩展到 [B, H, L_Q, L_K, E] 的形状
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q
# 根据 index_sample 对 K_expand 进行采样,得到一个形状为 [B, H, L_Q, sample_k, E] 的张量
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
# 计算查询和采样的键之间的点积,得到一个形状为 [B, H, L_Q, sample_k] 的张量
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)
# 找到具有稀疏度度量的Top-K查询
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
M_top = M.topk(n_top, sorted=False)[1]
# 使用减少的Q计算Q_K
Q_reduce = Q[torch.arange(B)[:, None, None],
torch.arange(H)[None, :, None],
M_top, :] # factor*ln(L_q)
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
# M.shape
# torch.Size([32, 8, 96])
# M_top.shape
# torch.Size([32, 8, 25])
# Q_reduce.shape
# torch.Size([32, 8, 25, 64])
# Q_K.shape
# torch.Size([32, 8, 25, 96])
return Q_K, M_top
二、计算最后输出
1. 基于注意力分数(形状为 [B, H, n_top, L_K] )和初始化的V计算得到结果(形状为 [B, H, n_top, D] )
2. 基于V的计算填充最后输出(若不使用掩码,则计算 V 在最后一个维度上的均值,形状为 [B, H, D],并在新添加的维度上扩展到 [B, H, L_Q, D];若使用掩码,则计算 V 在最后一个维度上的累计和,形状为 [B, H, L_V, D])
3. 最后输出形状为 [B, H, L_Q, D]
注意:最后输出形状第三维度中,n_top < L_Q,n_top部分基于1计算可得,L_Q - n_top部分基于2计算可得
代码如下
def _get_initial_context(self, V, L_Q):
B, H, L_V, D = V.shape
if not self.mask_flag:
# 计算 V 在最后一个维度上的均值,形状为 [B, H, D]
V_sum = V.mean(dim=-2)
# 将 V_sum 在新添加的维度上扩展到 [B, H, L_Q, D],并返回它的副本
contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
else: # 使用掩码
assert(L_Q == L_V)
# 断言查询长度 L_Q 和值的长度 L_V 相等,即只适用于自注意力(self-attention)
contex = V.cumsum(dim=-2)
# V.shape
# torch.Size([32, 8, 72, 64])
# contex.shape
# torch.Size([32, 8, 72, 64])
return contex
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
B, H, L_V, D = V.shape
if self.mask_flag:
attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
scores.masked_fill_(attn_mask.mask, -np.inf)
# 使用softmax函数计算注意力权重,形状为 [B, H, n_top, L_V]
attn = torch.softmax(scores, dim=-1)
# 使用注意力权重 attn 和值 V 进行加权求和
# 并将结果赋值到 context_in 的对应位置,对应位置形状为 [B, H, n_top, D]
context_in[torch.arange(B)[:, None, None],
torch.arange(H)[None, :, None],
index, :] = torch.matmul(attn, V).type_as(context_in)
if self.output_attention:
# 初始化一个全为1的张量,并将其除以 L_V,形成形状为 [B, H, L_V, L_V]
attns = (torch.ones([B, H, L_V, L_V])/L_V).type_as(attn).to(attn.device)
# 将计算得到的注意力权重 attn 赋值到 attns 的对应位置
attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
return (context_in, attns)
else:
return (context_in, None)
def forward(self, queries, keys, values, attn_mask):
B, L_Q, H, D = queries.shape
_, L_K, _, _ = keys.shape
queries = queries.transpose(2,1)
keys = keys.transpose(2,1)
values = values.transpose(2,1)
U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k)
u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q)
U_part = U_part if U_part<L_K else L_K
u = u if u<L_Q else L_Q
scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)
# 加入缩放系数
scale = self.scale or 1./sqrt(D)
if scale is not None:
scores_top = scores_top * scale
# 初始化上下文向量
context = self._get_initial_context(values, L_Q)
# 更新上下文向量
context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)
return context.transpose(2,1).contiguous(), attn
掩码计算代码如下
class ProbMask():
def __init__(self, B, H, L, index, scores, device="cpu"):
"""
创建一个形状为 (L, scores.shape[-1]) 的全为 True 的布尔张量
.to(device): 将张量移动到指定的设备(CPU 或 GPU)
.triu(1): 获取该张量的上三角部分(不包括主对角线),下三角部分设为 False
"""
_mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
"""
_mask[None, None, :]: 在前两个维度上增加两个维度,使其形状变为 (1, 1, L, scores.shape[-1])。
.expand(B, H, L, scores.shape[-1]): 扩展张量,使其形状变为 (B, H, L, scores.shape[-1])
"""
_mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])
"""
torch.arange(B)[:, None, None]: 创建一个形状为 (B, 1, 1) 的张量,用于索引批次维度。
torch.arange(H)[None, :, None]: 创建一个形状为 (1, H, 1) 的张量,用于索引注意力头维度。
index: 索引序列维度
"""
indicator = _mask_ex[torch.arange(B)[:, None, None],
torch.arange(H)[None, :, None],
index, :].to(device)
self._mask = indicator.view(scores.shape).to(device)
@property
def mask(self):
return self._mask