(Plug-and-Play Modules - Attention Series) 14. (WACV 2021) Efficient Attention


Paper: Efficient Attention: Attention with Linear Complexities
Code: https://github.com/cmsflash/efficient-attention


1. Efficient Attention

Dot-product attention is widely used in computer vision and natural language processing. However, its computational cost grows quadratically with the input size, which prevents it from being applied to high-resolution inputs. To overcome this drawback, the paper proposes a new efficient attention mechanism (Efficient Attention, EA) that is mathematically equivalent to dot-product attention but has a much lower computational cost. In addition, EA can be integrated into other networks very flexibly.

The paper first identifies the key weakness of dot-product attention, namely its excessive computational cost, and then proposes a more efficient attention mechanism that is mathematically equivalent to dot-product attention but faster and cheaper. In EA, the input features X are still passed through three linear layers to form the queries Q, keys K, and values V. The difference lies in the order of multiplication: dot-product attention first uses Q and K to build a pixel-wise attention map and then lets every pixel aggregate the values V with that map, whereas EA reverses the order, first aggregating K with V to form global context vectors and then using Q to take a weighted sum over these context vectors to produce the output, as sketched in the example below.
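The reordering can be illustrated with a small, stand-alone sketch (not the paper's module): when the attention map is only linearly normalized (the paper's "scaling" variant, i.e. without softmax), the two multiplication orders are exactly equal by associativity of matrix multiplication, but the second one never materializes the n x n map. The sizes below are illustrative.

import torch

# n positions (h * w), d_k key channels, d_v value channels -- toy sizes.
n, d_k, d_v = 4096, 32, 32
Q = torch.randn(n, d_k, dtype=torch.float64)
K = torch.randn(n, d_k, dtype=torch.float64)
V = torch.randn(n, d_v, dtype=torch.float64)

out_dot = (Q @ K.T) @ V   # dot-product order: builds an n x n map, O(n^2 * d)
out_eff = Q @ (K.T @ V)   # efficient order: builds a d_k x d_v context, O(n * d^2)

print(torch.allclose(out_dot, out_eff))  # True, up to floating-point error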

EA also brings a new interpretation of the attention mechanism. In dot-product attention, taking position i as the reference position, the similarities between i and every other position are collected into an attention map S_i for that position. The attention map S_i describes how strongly position i attends to each position j in the input: the higher the value of S_i at position j, the more position i attends to position j. Dot-product attention computes such an attention map S_i for every position i.

In contrast, efficient attention does not generate an attention map for every position. Instead, it interprets the keys K as a set of attention maps K_t. Each K_t is a global attention map that does not correspond to any particular position; rather, each one corresponds to a semantic aspect of the whole input, as written out below.
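Written as formulas (a sketch based on the reference code, with Q and K of shape n x d_k, V of shape n x d_v, and n = h * w positions): dot-product attention computes D(Q, K, V) = ρ(Q Kᵀ) V, which materializes an n x n attention map, whereas efficient attention computes E(Q, K, V) = ρ_q(Q) (ρ_k(K)ᵀ V), which only materializes a d_k x d_v context matrix. In the softmax variant, ρ_k normalizes each key channel over the n positions, so every column of K acts as one global attention map K_t, and ρ_q normalizes each query over its d_k channels.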


Structure of EA compared with dot-product attention:

2. Code Implementation

import torch
from torch import nn
from torch.nn import functional as f


class EfficientAttention(nn.Module):

    def __init__(self, in_channels, key_channels=2, head_count=2, value_channels=2):
        super().__init__()
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.head_count = head_count
        self.value_channels = value_channels

        # 1x1 convolutions act as the three linear projections for K, Q, V.
        self.keys = nn.Conv2d(in_channels, key_channels, 1)
        self.queries = nn.Conv2d(in_channels, key_channels, 1)
        self.values = nn.Conv2d(in_channels, value_channels, 1)
        self.reprojection = nn.Conv2d(value_channels, in_channels, 1)

    def forward(self, input_):
        n, _, h, w = input_.size()
        # Flatten the spatial dimensions: (n, c, h, w) -> (n, c, h * w).
        keys = self.keys(input_).reshape((n, self.key_channels, h * w))
        queries = self.queries(input_).reshape(n, self.key_channels, h * w)
        values = self.values(input_).reshape((n, self.value_channels, h * w))
        head_key_channels = self.key_channels // self.head_count
        head_value_channels = self.value_channels // self.head_count

        attended_values = []
        for i in range(self.head_count):
            # Keys: softmax over the h * w positions (dim=2), so each key
            # channel becomes a global attention map over the input.
            key = f.softmax(keys[
                            :,
                            i * head_key_channels: (i + 1) * head_key_channels,
                            :
                            ], dim=2)
            # Queries: softmax over the key channels (dim=1) at each position.
            query = f.softmax(queries[
                              :,
                              i * head_key_channels: (i + 1) * head_key_channels,
                              :
                              ], dim=1)
            value = values[
                    :,
                    i * head_value_channels: (i + 1) * head_value_channels,
                    :
                    ]
            # Aggregate keys with values first: a (d_k, d_v) global context,
            # instead of an (h * w, h * w) pixel-wise attention map.
            context = key @ value.transpose(1, 2)
            # Each position then gathers from the global context via its query.
            attended_value = (
                    context.transpose(1, 2) @ query
            ).reshape(n, head_value_channels, h, w)
            attended_values.append(attended_value)

        # Concatenate heads, project back to in_channels, and add the residual.
        aggregated_values = torch.cat(attended_values, dim=1)
        reprojected_value = self.reprojection(aggregated_values)
        attention = reprojected_value + input_

        return attention


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(4, 64, 128, 128).to(device)
    model = EfficientAttention(64).to(device)
    out = model(x)
    print(out.shape)  # torch.Size([4, 64, 128, 128])
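As a further illustration of the plug-and-play claim, here is a minimal, hypothetical sketch of wrapping the module around an ordinary convolutional block. ConvBlockWithEA and the channel/head settings are made up for this example and are not taken from the paper or the repository; it only assumes the EfficientAttention class defined above.

import torch
from torch import nn

class ConvBlockWithEA(nn.Module):
    def __init__(self, channels=64):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )
        # key/value channels and head count are free hyperparameters.
        self.attn = EfficientAttention(channels, key_channels=32,
                                       head_count=4, value_channels=32)

    def forward(self, x):
        return self.attn(self.conv(x))

block = ConvBlockWithEA(64)
y = block(torch.randn(2, 64, 56, 56))
print(y.shape)  # torch.Size([2, 64, 56, 56])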

This post only collects the plug-and-play modules from the paper, and some details are inevitably left out. For a more thorough understanding of these modules, please read the original paper; you will certainly get more out of it.
