即插即用卷积之TalkingHeadAttn

本文详细介绍了TalkingHeadAttn模块,它利用多头自注意力机制处理序列数据,通过查询-键-值投影和注意力计算提高特征表示和建模效果,适用于自然语言处理等任务。
摘要由CSDN通过智能技术生成
import torch
from torch import nn

"""
个代码实现了一个名为"TalkingHeadAttn"的自注意力模块(Self-Attention),主要用于增强神经网络在输入序列上的特征表示和建模。以下是这个自注意力模块的关键部分和特点:

多头自注意力:这个模块使用了多头自注意力机制,通过将输入数据进行不同方式的投影来构建多个注意力头。num_heads参数指定了注意力头的数量,每个头将学习捕捉输入序列中不同的特征关系。

查询-键-值(QKV)投影:模块使用线性变换(nn.Linear)将输入 x 投影到查询(Q),键(K),和值(V)的空间。这个投影操作是通过self.qkv完成的。注意,为了提高计算效率,一次性生成了三个部分的投影结果。

注意力计算:通过计算 Q 和 K 的点积,然后应用 Softmax 操作,得到了注意力矩阵,表示了输入序列中各个位置之间的关联程度。这个计算是通过 attn = q @ k.transpose(-2, -1) 和 attn = attn.softmax(dim=-1) 完成的。

多头特征整合:多头注意力的输出被整合在一起,通过乘以值(V)矩阵,并进行线性变换,将多个头的结果整合到一起。这个整合过程包括了投影 self.proj_l 和 self.proj_w 操作。

Dropout正则化:在注意力计算和投影操作之后,使用 Dropout 来进行正则化,减少过拟合风险。

输出:最终的输出是通过 self.proj 和 self.proj_drop 完成的。

总的来说,TalkingHeadAttn模块通过多头自注意力机制,能够同时考虑输入序列中不同位置之间的关系,以及不同的特征关系。这有助于提高模型在序列数据上的特征提取和建模能力,使其在自然语言处理和其他序列数据任务中表现出色。这个模块通常作为大型神经网络模型的子模块,用于处理序列数据。
"""

class TalkingHeadAttn(nn.Module):

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()

        self.num_heads = num_heads

        head_dim = dim // num_heads

        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = nn.Linear(dim, dim)

        self.proj_l = nn.Linear(num_heads, num_heads)
        self.proj_w = nn.Linear(num_heads, num_heads)

        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]

        attn = q @ k.transpose(-2, -1)

        attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        attn = attn.softmax(dim=-1)

        attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


if __name__ == '__main__':
    block = TalkingHeadAttn(dim=128)
    input = torch.rand(32, 784, 128)
    output = block(input)
    print(input.size())
    print(output.size())

在这里插入图片描述

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
深度学习即插即用涨点是指通过在现有的深度学习模型中添加一些特定的层或模块,以提升模型的性能和效果。其中涨点指的是在验证集或测试集上的性能提升。 其中,DO-Conv是指Depthwise Over-parameterized Convolutional Layer,这是一种在深度学习中用于卷积操作的技术,可以减少参数量并提高模型的效率和推理速度。它通常应用于卷积神经网络(CNN)中,用于提取图像或其他数据的特征。 DeepLab是一种用于语义图像分割的深度卷积网络,它使用了Atrous Conv(空洞卷积)技术。Atrous Conv允许网络以不同的采样率(或称为膨胀率)处理输入数据,以获得更大的感受野并捕捉更多的上下文信息,从而提高分割的准确性。 DCANet是一种用于卷积神经网络的注意力学习方法,它通过学习连接的注意力模块来改善网络的性能。这种方法可以使网络更加关注重要的特征并抑制不相关的特征,从而提高模型在各种任务上的性能。 因此,要实现深度学习即插即用的涨点,可以考虑使用DO-Conv、DeepLab和DCANet等技术,根据具体的任务和数据集的需求,选择合适的模块或层来增强模型的性能。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [大盘点 | 十大即插即用的涨点神器!](https://blog.csdn.net/qq_36387683/article/details/108406297)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *2* [【深度学习】真正的即插即用!盘点11种CNN网络设计中精巧通用的“小”插件...](https://blog.csdn.net/fengdu78/article/details/124113703)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值