paper:Efficient Attention: Attention with Linear Complexities
Code:https://github.com/cmsflash/efficient-attention
1、Efficient Attention
点积注意力 在计算机视觉与自然语言处理领域有着广泛的应用。然而,点积注意力的计算成本与输入大小呈二次增长。这种增长阻碍了在高分辨率输入上的应用。为了弥补 点积注意力 的这一缺陷,论文提出了一种新的高效注意力机制(Efficient Attention),EA 与 点积注意力 等价,但具有更少的计算开销。此外,EA 还可以极其灵活地集成到其他网络中。
论文中首先论证到点积注意力的关键缺陷,即过高的计算成本与开销。所以本文提出了一种更为有效的注意力机制,这是一种在数学上与 点积注意力 等价的注意力,但更快、更有效。在EA中,特征向量 X 仍然通过三个线性层以形成查询Q、键 K 和值V。EA 和点积注意力的不同之处在于:点积注意力首先使用 Q 与 K 合成像素级注意力图,然后使用 V 让每个像素将值与注意力图聚合;而 EA 则是改变了乘法顺序,首先通过将 K 与 V 聚合以形成全局上下文向量,再使用 Q 对上下文向量进行加权求和获得输出。
EA 带来了对注意力机制的新诠释。在点积注意力中,选择位置 i 作为参考位置,可以收集所有位置与位置的相似性,并形成该位置的注意力图 S_i 。注意力t图 S_i 表示位置i关注输入中的每个位置j的程度。S_i 上位置 j 的值越高,意味着位置 i 会更多地关注位置 j 。在点积注意力中,每个位置 i 都有这样的注意力图 S_i。
相比之下,高效注意力不会为每个位置生成注意力图。相反,它将 K 解释为注意力映射 K_t。每个K_t是一个全局注意力图,不对应于任何特定位置。它们中的每一个都对应于整个输入的语义方面。
EA 与 点积注意力 结构图:
2、代码实现
import torch
from torch import nn
from torch.nn import functional as f
class EfficientAttention(nn.Module):
def __init__(self, in_channels, key_channels=2, head_count=2, value_channels=2):
super().__init__()
self.in_channels = in_channels
self.key_channels = key_channels
self.head_count = head_count
self.value_channels = value_channels
self.keys = nn.Conv2d(in_channels, key_channels, 1)
self.queries = nn.Conv2d(in_channels, key_channels, 1)
self.values = nn.Conv2d(in_channels, value_channels, 1)
self.reprojection = nn.Conv2d(value_channels, in_channels, 1)
def forward(self, input_):
n, _, h, w = input_.size()
keys = self.keys(input_).reshape((n, self.key_channels, h * w))
queries = self.queries(input_).reshape(n, self.key_channels, h * w)
values = self.values(input_).reshape((n, self.value_channels, h * w))
head_key_channels = self.key_channels // self.head_count
head_value_channels = self.value_channels // self.head_count
attended_values = []
for i in range(self.head_count):
key = f.softmax(keys[
:,
i * head_key_channels: (i + 1) * head_key_channels,
:
], dim=2)
query = f.softmax(queries[
:,
i * head_key_channels: (i + 1) * head_key_channels,
:
], dim=1)
value = values[
:,
i * head_value_channels: (i + 1) * head_value_channels,
:
]
context = key @ value.transpose(1, 2)
attended_value = (
context.transpose(1, 2) @ query
).reshape(n, head_value_channels, h, w)
attended_values.append(attended_value)
aggregated_values = torch.cat(attended_values, dim=1)
reprojected_value = self.reprojection(aggregated_values)
attention = reprojected_value + input_
return attention
if __name__ == '__main__':
x = torch.randn(4, 64, 128, 128).cuda()
model = EfficientAttention(64).cuda()
out = model(x)
print(out.shape)
本文只是对论文中的即插即用模块做了整合,对论文中的一些地方难免有遗漏之处,如果想对这些模块有更详细的了解,还是要去读一下原论文,肯定会有更多收获。