sliding window attention是为了解决在输出序列长度sequence length很大的时候attention计算量爆炸增长的问题。
用一句话来总结sliding window attention其实就是:每一个token只和包含其本身在内的前W个token做Attention。最简单的实现其实就是给不需要计算attention的其它token都加上一个mask就可以了,是不是非常简单?
用图片更直观一些,如下(图片来源:图解Mixtral 8 * 7b推理优化原理与源码实现):
核心代码如下:
def scaled_dot_product_attention(q, k, v, window_size, mask=None):
matmul_qk = torch.matmul(q, k.transpose(-1, -2))
dk = torch.tensor(k.shape[-1], dtype=torch.float32)
scaled_attention_logits = matmul_qk / torch.sqrt(dk)
if mask is not None:
scaled_attention_logits += mask * -