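This article continues the previous one on adding advanced attention mechanisms to YOLOv10. More than 30 attention modules are covered in total, to help your experiments gain accuracy 🚀🚀
Tip: if you like this column, please like and follow. This article is for learning and exchange only. Writing it took real effort; do not copy or repost without the author's permission!!!
Attention Module Introductions 🛩️🛩️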
6、CoTAttention (Contextual Transformer Attention)🌱🌱
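Paper: https://arxiv.org/abs/2107.12292
CoTAttention, proposed in 2021, is an attention mechanism for visual recognition that combines the strengths of convolutional neural networks and Transformers to improve how a network captures and uses contextual information. It first extracts static context among neighboring keys with a k×k convolution, then uses that context together with the query to learn a dynamic attention map. This helps the network handle complex visual tasks.
Advantages of CoTAttention
- Stronger context awareness: by modeling context explicitly, CoTAttention captures richer contextual information and improves the network's understanding of complex scenes.
- Good compatibility: CoTAttention can be combined with existing convolutional neural networks and Transformer models to improve their performance.
- Performance gains: by enriching feature representations, CoTAttention improves classification, detection and other visual recognition tasks.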
```python
import torch
from torch import nn
from torch.nn import functional as F


class CoTAttention(nn.Module):
    def __init__(self, dim=512, kernel_size=3):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        # Static context: k x k grouped conv over the keys (dim must be divisible by 4)
        self.key_embed = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=4, bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU()
        )
        # Values: 1x1 conv
        self.value_embed = nn.Sequential(
            nn.Conv2d(dim, dim, 1, bias=False),
            nn.BatchNorm2d(dim)
        )
        factor = 4
        # Dynamic attention weights predicted from [static context, query]
        self.attention_embed = nn.Sequential(
            nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),
            nn.BatchNorm2d(2 * dim // factor),
            nn.ReLU(),
            nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1)
        )

    def forward(self, x):
        bs, c, h, w = x.shape
        k1 = self.key_embed(x)  # bs,c,h,w (static context)
        v = self.value_embed(x).view(bs, c, -1)  # bs,c,h*w
        y = torch.cat([k1, x], dim=1)  # bs,2c,h,w
        att = self.attention_embed(y)  # bs,c*k*k,h,w
        att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)
        att = att.mean(2, keepdim=False).view(bs, c, -1)  # bs,c,h*w
        k2 = F.softmax(att, dim=-1) * v  # dynamic context
        k2 = k2.view(bs, c, h, w)
        return k1 + k2  # fuse static and dynamic context
```
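To get a feel for what CoTAttention costs when inserted into a wide YOLOv10 layer, the following pure-Python sketch counts its learnable parameters layer by layer from the definitions in the class above. BatchNorm running statistics are buffers rather than parameters, so only the affine weight and bias are counted. The helper names are hypothetical, and the count assumes the layer configuration exactly as written in this article.

```python
# Hypothetical helpers: hand-count the learnable parameters of CoTAttention
# as defined in this article (no PyTorch needed).

def conv_params(c_in, c_out, k=1, groups=1, bias=False):
    """nn.Conv2d weights: c_out * (c_in / groups) * k * k, plus c_out if bias."""
    return c_out * (c_in // groups) * k * k + (c_out if bias else 0)


def bn_params(c):
    """Learnable affine parameters of nn.BatchNorm2d (weight + bias)."""
    return 2 * c


def cot_params(dim=512, kernel_size=3, factor=4):
    k2 = kernel_size * kernel_size
    key = conv_params(dim, dim, kernel_size, groups=4) + bn_params(dim)  # static context
    value = conv_params(dim, dim) + bn_params(dim)
    hidden = 2 * dim // factor
    attn = (conv_params(2 * dim, hidden) + bn_params(hidden)
            + conv_params(hidden, k2 * dim, bias=True))  # dynamic attention weights
    return {"key_embed": key, "value_embed": value,
            "attention_embed": attn, "total": key + value + attn}


if __name__ == "__main__":
    for name, n in cot_params(512).items():
        print(f"{name:>16}: {n:,}")
```

For dim=512 and kernel_size=3 this gives 2,300,928 parameters. About half of them (1,184,256) come from the last 1×1 convolution of attention_embed, whose output has k·k·dim channels.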
7、MobileViTAttention🌱🌱
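Paper: https://arxiv.org/pdf/2110.02178
MobileViTAttention, proposed in 2022, is a lightweight attention module that combines convolution with self-attention and is designed for mobile devices and other resource-constrained environments. By fusing local and global features, it improves feature representation and performance across a range of vision tasks. Its small footprint and efficiency make it well suited to mobile vision applications.
Advantages of MobileViTAttention
- Lightweight: designed with mobile constraints in mind, MobileViTAttention keeps computational complexity and parameter count low.
- Global context awareness: through self-attention, MobileViTAttention captures global context and strengthens the feature representation.
- Efficient feature fusion: combining local and global features improves the model's performance on vision tasks.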
```python
import torch
from torch import nn
from einops import rearrange


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.ln(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, mlp_dim, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads, head_dim, dropout):
        super().__init__()
        inner_dim = heads * head_dim
        project_out = not (heads == 1 and head_dim == dim)
        self.heads = heads
        self.scale = head_dim ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):  # x: b, p (pixels per patch), n (patches), dim
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, head_dim, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, head_dim, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            ]))

    def forward(self, x):
        out = x
        for att, ffn in self.layers:
            out = out + att(out)
            out = out + ffn(out)
        return out


class MobileViTAttention(nn.Module):
    def __init__(self, in_channel=3, dim=512, kernel_size=3, patch_size=7):
        super().__init__()
        self.ph, self.pw = patch_size, patch_size
        self.conv1 = nn.Conv2d(in_channel, in_channel, kernel_size=kernel_size, padding=kernel_size // 2)
        self.conv2 = nn.Conv2d(in_channel, dim, kernel_size=1)
        self.trans = Transformer(dim=dim, depth=3, heads=8, head_dim=64, mlp_dim=1024)
        self.conv3 = nn.Conv2d(dim, in_channel, kernel_size=1)
        self.conv4 = nn.Conv2d(2 * in_channel, in_channel, kernel_size=kernel_size, padding=kernel_size // 2)

    def forward(self, x):
        # h and w must be divisible by patch_size
        ## Local representation
        y = self.conv2(self.conv1(x))  # bs,dim,h,w
        ## Global representation
        _, _, h, w = y.shape
        y = rearrange(y, 'bs dim (nh ph) (nw pw) -> bs (ph pw) (nh nw) dim',
                      ph=self.ph, pw=self.pw)  # bs,ph*pw,nh*nw,dim
        y = self.trans(y)
        y = rearrange(y, 'bs (ph pw) (nh nw) dim -> bs dim (nh ph) (nw pw)', ph=self.ph, pw=self.pw,
                      nh=h // self.ph, nw=w // self.pw)  # bs,dim,h,w
        ## Fusion
        y = self.conv3(y)  # bs,c,h,w
        y = torch.cat([x, y], 1)  # bs,2c,h,w
        y = self.conv4(y)  # bs,c,h,w
        return y
```
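The first rearrange in MobileViTAttention.forward decides which pixels attend to each other. The pure-Python sketch below reproduces the pattern `'bs dim (nh ph) (nw pw) -> bs (ph pw) (nh nw) dim'` for one image and one channel. Pixels at the same offset inside every ph×pw patch form one sequence, and the Transformer then attends across the nh·nw patches of that sequence. The function names are hypothetical and only mimic the index mapping einops performs.

```python
# Hypothetical pure-Python version of the two rearrange calls, one channel only.

def unfold(grid, ph, pw):
    h, w = len(grid), len(grid[0])
    if h % ph or w % pw:
        raise ValueError("h and w must be divisible by the patch size")
    nh, nw = h // ph, w // pw
    # seqs[p][n]: p = offset inside the patch, n = patch index
    seqs = [[None] * (nh * nw) for _ in range(ph * pw)]
    for r in range(h):
        for c in range(w):
            i, a = r % ph, r // ph  # offset row, patch row
            j, b = c % pw, c // pw  # offset col, patch col
            seqs[i * pw + j][a * nw + b] = grid[r][c]
    return seqs


def fold(seqs, h, w, ph, pw):
    """Inverse of unfold (the second rearrange in forward)."""
    nw = w // pw
    return [[seqs[(r % ph) * pw + (c % pw)][(r // ph) * nw + (c // pw)]
             for c in range(w)] for r in range(h)]


if __name__ == "__main__":
    g = [[r * 4 + c for c in range(4)] for r in range(4)]  # 4x4 map, values 0..15
    s = unfold(g, 2, 2)
    print(s[0])  # [0, 2, 8, 10]: top-left pixel of each of the 4 patches
    assert fold(s, 4, 4, 2, 2) == g
```

The divisibility check matters in practice. With a 640×640 input, YOLO feature maps at strides 8, 16 and 32 are 80×80, 40×40 and 20×20, and none of these is divisible by the default patch_size=7. A smaller patch size such as 2 or 4 (or padding the input) would be needed.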
8、SimAM(Simple Attention Module)🌱🌱
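Paper: http://proceedings.mlr.press/v139/yang21o/yang21o.pdf
SimAM (Simple Attention Module) is a parameter-free attention mechanism for improving convolutional neural networks. Its design is simple and efficient. It defines an energy function that measures the importance of each neuron and derives the attention weights from the closed-form minimum of that energy. Its key advantage is that it adds no parameters and very little computation, so it is easy to embed in existing networks.
Advantages of SimAM
- Simple: SimAM introduces no additional trainable parameters, keeping the model compact.
- Efficient: the energy function has a closed-form solution that is cheap to compute, so the model's computational cost barely increases.
- Good compatibility: it can easily be embedded in a wide range of existing convolutional architectures to improve their performance.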
```python
import torch
import torch.nn as nn


class SimAM(nn.Module):
    def __init__(self, e_lambda=1e-4):
        super().__init__()
        self.activation = nn.Sigmoid()
        self.e_lambda = e_lambda  # regularization term lambda in the energy function

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):
        b, c, h, w = x.size()
        n = w * h - 1
        # squared deviation of every neuron from its channel mean
        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        # inverse of the minimal energy: larger for more distinctive neurons
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
        return x * self.activation(y)
```
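The line computing `y` in SimAM.forward is the inverse of the closed-form minimal energy. Writing the energy as e_t = 4(σ² + λ) / ((t − μ)² + 2σ² + 2λ), its inverse expands to (t − μ)² / (4(σ² + λ)) + 0.5, which is exactly what the code evaluates. The pure-Python sketch below checks this numerically on one channel. It uses the same variance estimate as the code (sum of squared deviations divided by h·w − 1). The function names are hypothetical.

```python
import math

# Hypothetical helpers: compare the energy-based form with the code's form.

def inverse_energy_paper(values, lam=1e-4):
    n = len(values) - 1                      # same n = h * w - 1 as the code
    mu = sum(values) / len(values)
    var = sum((v - mu) ** 2 for v in values) / n
    return [((v - mu) ** 2 + 2 * var + 2 * lam) / (4 * (var + lam)) for v in values]


def inverse_energy_code(values, lam=1e-4):
    n = len(values) - 1
    mu = sum(values) / len(values)
    d2 = [(v - mu) ** 2 for v in values]
    return [d / (4 * (sum(d2) / n + lam)) + 0.5 for d in d2]


def simam_channel(values, lam=1e-4):
    """Scale each value by sigmoid(1 / e_t), as SimAM.forward does."""
    return [v / (1 + math.exp(-y))
            for v, y in zip(values, inverse_energy_code(values, lam))]


if __name__ == "__main__":
    vals = [0.1, 0.5, 2.0, -1.0, 0.3]
    print(inverse_energy_paper(vals))
    print(inverse_energy_code(vals))
    print(simam_channel(vals))
```

Because the importance grows with (t − μ)², the value that deviates most from the channel mean (2.0 here) receives the largest weight, and every weight is at least sigmoid(0.5).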
9、SKNets (Selective Kernel Networks)🌱🌱
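Paper: https://arxiv.org/pdf/1903.06586
Selective Kernel Networks (SKNets) extend convolutional neural networks with a selective kernel mechanism that makes them more flexible and expressive. The core idea is to run several convolution kernels of different sizes in parallel and let a selection module weight the branches dynamically according to the current input. This captures features at different scales while keeping computation efficient.
Advantages of SKNets
- Multi-scale feature extraction: by running kernels of several sizes in parallel, SKNets capture features at different scales and become more expressive.
- Dynamic feature fusion: using global information and attention, SKNets weight the kernel branches that best suit the current input, making feature extraction more accurate.
- Efficient: despite the multi-scale kernels and the attention step, SKNets keep computational cost and parameter count relatively low.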
```python
from collections import OrderedDict

import torch
from torch import nn


class SKAttention(nn.Module):
    def __init__(self, channel=512, kernels=(1, 3, 5, 7), reduction=16, group=1, L=32):
        super().__init__()
        self.d = max(L, channel // reduction)
        # Split: one conv branch per kernel size (odd kernels keep h, w unchanged)
        self.convs = nn.ModuleList([])
        for k in kernels:
            self.convs.append(
                nn.Sequential(OrderedDict([
                    ('conv', nn.Conv2d(channel, channel, kernel_size=k, padding=k // 2, groups=group)),
                    ('bn', nn.BatchNorm2d(channel)),
                    ('relu', nn.ReLU())
                ]))
            )
        self.fc = nn.Linear(channel, self.d)
        self.fcs = nn.ModuleList([])
        for _ in kernels:
            self.fcs.append(nn.Linear(self.d, channel))
        self.softmax = nn.Softmax(dim=0)  # normalize across the branches

    def forward(self, x):
        bs, c, _, _ = x.size()
        conv_outs = []
        ### split
        for conv in self.convs:
            conv_outs.append(conv(x))
        feats = torch.stack(conv_outs, 0)  # k,bs,channel,h,w
        ### fuse
        U = sum(conv_outs)  # bs,c,h,w
        ### reduction channel
        S = U.mean(-1).mean(-1)  # bs,c
        Z = self.fc(S)  # bs,d
        ### calculate attention weight
        weights = []
        for fc in self.fcs:
            weight = fc(Z)
            weights.append(weight.view(bs, c, 1, 1))  # bs,channel,1,1
        attention_weights = torch.stack(weights, 0)  # k,bs,channel,1,1
        attention_weights = self.softmax(attention_weights)  # k,bs,channel,1,1
        ### select: weighted sum of the branches
        V = (attention_weights * feats).sum(0)
        return V
```
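The "selection" in SKAttention is a soft one. Each branch gets a logit from its own fc layer, and `nn.Softmax(dim=0)` normalizes those logits across the stacked branches. For every channel the branch weights therefore sum to 1, and the output is a convex combination of the branch outputs. The pure-Python sketch below shows this for a single channel and position. The branch values and logits are made-up numbers, and the function names are hypothetical.

```python
import math

# Hypothetical helpers: softmax over branches, then weighted sum (one channel).

def softmax(logits):
    m = max(logits)                      # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def select(branch_values, branch_logits):
    """branch_values[k]: output of the k-th kernel branch at one position/channel."""
    weights = softmax(branch_logits)
    return sum(w * v for w, v in zip(weights, branch_values)), weights


if __name__ == "__main__":
    values = [0.2, 1.0, 0.6, -0.4]   # illustrative outputs of the 1x1/3x3/5x5/7x7 branches
    logits = [0.1, 2.0, 0.5, -1.0]   # illustrative stand-ins for fc(Z) per branch
    out, w = select(values, logits)
    print([round(x, 3) for x in w], round(out, 3))
```

Because the weights are computed per channel, one channel can favour the 3×3 branch while another favours the 7×7 branch.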
10、Shuffle Attention🌱🌱
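Paper: https://arxiv.org/pdf/2102.00240.pdf
Shuffle Attention (SA) combines channel shuffle, channel attention and spatial attention to strengthen the feature representation of convolutional neural networks. It splits the channels into groups and, inside each group, reweights the important features along the channel dimension (one half) and the spatial dimensions (the other half). A final channel shuffle then lets information flow between groups. This improves expressiveness and generalization while keeping computation efficient. SA is applicable to many computer vision tasks, such as image classification, object detection and semantic segmentation.
Advantages of SA
- Richer features: channel shuffle mixes information across sub-feature groups and, combined with the two attention branches, enriches the resulting feature representation.
- Efficiency: SA needs only a small amount of extra computation and parameters, so it strengthens the model while keeping it efficient.
- Flexibility: SA can easily be embedded in a wide range of existing convolutional architectures to improve their performance.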
```python
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter


class ShuffleAttention(nn.Module):
    def __init__(self, channel=512, reduction=16, G=8):
        super().__init__()
        # channel must be divisible by 2 * G
        self.G = G
        self.channel = channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        # Generic initializer; not called automatically
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)
        # flatten
        x = x.reshape(b, -1, h, w)
        return x

    def forward(self, x):
        b, c, h, w = x.size()
        # group into subfeatures
        x = x.view(b * self.G, -1, h, w)  # bs*G,c//G,h,w
        # channel_split
        x_0, x_1 = x.chunk(2, dim=1)  # bs*G,c//(2*G),h,w
        # channel attention
        x_channel = self.avg_pool(x_0)  # bs*G,c//(2*G),1,1
        x_channel = self.cweight * x_channel + self.cbias  # bs*G,c//(2*G),1,1
        x_channel = x_0 * self.sigmoid(x_channel)
        # spatial attention
        x_spatial = self.gn(x_1)  # bs*G,c//(2*G),h,w
        x_spatial = self.sweight * x_spatial + self.sbias  # bs*G,c//(2*G),h,w
        x_spatial = x_1 * self.sigmoid(x_spatial)  # bs*G,c//(2*G),h,w
        # concatenate along channel axis
        out = torch.cat([x_channel, x_spatial], dim=1)  # bs*G,c//G,h,w
        out = out.contiguous().view(b, -1, h, w)
        # channel shuffle
        out = self.channel_shuffle(out, 2)
        return out
```
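The shuffle at the end of ShuffleAttention.forward is a fixed permutation of channels: reshape to (groups, c/groups), transpose, flatten. With groups=2 it interleaves the first and second halves of the channels, so sub-features from different groups end up next to each other for the following layer. The pure-Python sketch below applies the same three steps to channel indices. The function name is hypothetical.

```python
# Hypothetical helper: the channel order produced by channel_shuffle.

def channel_shuffle_order(num_channels, groups):
    if num_channels % groups:
        raise ValueError("num_channels must be divisible by groups")
    per_group = num_channels // groups
    # reshape step: grid[g][i] is channel i inside group g
    grid = [[g * per_group + i for i in range(per_group)] for g in range(groups)]
    # permute step (transpose) followed by the final reshape (flatten)
    return [grid[g][i] for i in range(per_group) for g in range(groups)]


if __name__ == "__main__":
    print(channel_shuffle_order(8, 2))  # [0, 4, 1, 5, 2, 6, 3, 7]
```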
11、S2Attention🌱🌱
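Paper: https://arxiv.org/abs/2108.01072
S2Attention is an attention mechanism based on spatial shifts. It brings spatial information into an MLP architecture by shifting local features, which enables efficient feature interaction. The core idea is to shift parts of the feature map spatially (different channel groups in different directions) so that each position sees its neighbors' features, which captures local context. The design is inspired by shift operations in convolutional networks, but realizes a similar function inside an MLP structure.
Advantages:
- Better feature capture: the spatial shifts let S2Attention capture more local context, improving the expressiveness of the features.
- Simple and efficient: S2Attention achieves a convolution-like effect inside an MLP architecture at a low computational cost.
- Broadly applicable: S2Attention can be used in image classification, object detection, semantic segmentation and other vision tasks to improve model performance.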
```python
import torch
from torch import nn


def spatial_shift1(x):
    # x is channels-last; each quarter of the channels is shifted by one position
    # in a different direction, and border values are kept
    b, w, h, c = x.size()
    out = x.clone()  # write into a copy so the overlapping slices are read unmodified
    out[:, 1:, :, :c // 4] = x[:, :w - 1, :, :c // 4]
    out[:, :w - 1, :, c // 4:c // 2] = x[:, 1:, :, c // 4:c // 2]
    out[:, :, 1:, c // 2:c * 3 // 4] = x[:, :, :h - 1, c // 2:c * 3 // 4]
    out[:, :, :h - 1, 3 * c // 4:] = x[:, :, 1:, 3 * c // 4:]
    return out


def spatial_shift2(x):
    # same as spatial_shift1 with the two spatial axes swapped
    b, w, h, c = x.size()
    out = x.clone()
    out[:, :, 1:, :c // 4] = x[:, :, :h - 1, :c // 4]
    out[:, :, :h - 1, c // 4:c // 2] = x[:, :, 1:, c // 4:c // 2]
    out[:, 1:, :, c // 2:c * 3 // 4] = x[:, :w - 1, :, c // 2:c * 3 // 4]
    out[:, :w - 1, :, 3 * c // 4:] = x[:, 1:, :, 3 * c // 4:]
    return out


class SplitAttention(nn.Module):
    def __init__(self, channel=512, k=3):
        super().__init__()
        self.channel = channel
        self.k = k
        self.mlp1 = nn.Linear(channel, channel, bias=False)
        self.gelu = nn.GELU()
        self.mlp2 = nn.Linear(channel, channel * k, bias=False)
        self.softmax = nn.Softmax(1)

    def forward(self, x_all):
        b, k, h, w, c = x_all.shape
        x_all = x_all.reshape(b, k, -1, c)  # bs,k,n,c
        a = torch.sum(torch.sum(x_all, 1), 1)  # bs,c
        hat_a = self.mlp2(self.gelu(self.mlp1(a)))  # bs,k*c
        hat_a = hat_a.reshape(b, self.k, c)  # bs,k,c
        bar_a = self.softmax(hat_a)  # bs,k,c (softmax over the k branches)
        attention = bar_a.unsqueeze(-2)  # bs,k,1,c
        out = attention * x_all  # bs,k,n,c
        out = torch.sum(out, 1).reshape(b, h, w, c)
        return out


class S2Attention(nn.Module):
    def __init__(self, channels=512):
        super().__init__()
        self.mlp1 = nn.Linear(channels, channels * 3)
        self.mlp2 = nn.Linear(channels, channels)
        self.split_attention = SplitAttention(channels, k=3)  # must match the input width

    def forward(self, x):
        b, c, w, h = x.size()
        x = x.permute(0, 2, 3, 1)  # channels-last
        x = self.mlp1(x)
        x1 = spatial_shift1(x[:, :, :, :c])
        x2 = spatial_shift2(x[:, :, :, c:c * 2])
        x3 = x[:, :, :, c * 2:]
        x_all = torch.stack([x1, x2, x3], 1)
        a = self.split_attention(x_all)
        x = self.mlp2(a)
        x = x.permute(0, 3, 1, 2)
        return x
```
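Each assignment in spatial_shift1, such as `x[:, 1:, :, :c // 4] = x[:, :w - 1, :, :c // 4]`, moves one channel group by one position: every row takes the value of the previous row, and the first row keeps its own values. The pure-Python sketch below shows this on a single 2D channel. It also shows why the source should be read from an untouched copy. A naive sequential in-place loop reads rows that were already overwritten and smears the first row downwards. In PyTorch, copying between overlapping slices of the same tensor is likewise not guaranteed to behave like a clean shift. The function names are hypothetical.

```python
# Hypothetical helpers: a one-step spatial shift on a single channel.

def shift_down(grid):
    out = [row[:] for row in grid]   # copy, playing the role of x.clone()
    for r in range(1, len(grid)):
        out[r] = grid[r - 1][:]      # always read from the original grid
    return out


def shift_down_in_place_naive(grid):
    """Deliberately buggy: shifts a copy into itself, row by row."""
    g = [row[:] for row in grid]
    for r in range(1, len(g)):
        g[r] = g[r - 1][:]           # reads a row that was already overwritten
    return g


if __name__ == "__main__":
    g = [[1, 2], [3, 4], [5, 6]]
    print(shift_down(g))                 # [[1, 2], [1, 2], [3, 4]]
    print(shift_down_in_place_naive(g))  # [[1, 2], [1, 2], [1, 2]]
```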
12、TripletAttention🌱🌱
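Paper: https://arxiv.org/abs/2010.03045
TripletAttention strengthens a convolutional neural network's ability to model interactions across dimensions. It rotates (permutes) the input feature map along different axes so that attention is computed over the channel–width, height–channel and height–width planes, capturing richer contextual information.
It consists of three main steps:
- Rotation: the input feature map is rotated (its dimensions are permuted) to capture information along different directions.
- Attention computation: an attention weight map is computed separately for each rotated feature map.
- Fusion: the attention-weighted features from the different directions are rotated back and averaged to produce the final feature map.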
```python
import torch
import torch.nn as nn


class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
                 bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class ZPool(nn.Module):
    # Concatenate max- and mean-pooling along dim 1 -> 2 channels
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class AttentionGate(nn.Module):
    def __init__(self):
        super(AttentionGate, self).__init__()
        kernel_size = 7
        self.compress = ZPool()
        self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.conv(x_compress)
        scale = torch.sigmoid(x_out)  # attention map over the last two dims
        return x * scale


class TripletAttention(nn.Module):
    def __init__(self, no_spatial=False):
        super(TripletAttention, self).__init__()
        self.cw = AttentionGate()
        self.hc = AttentionGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.hw = AttentionGate()

    def forward(self, x):
        # Branch 1: swap C and H -> (B, H, C, W); attention over (C, W)
        x_perm1 = x.permute(0, 2, 1, 3).contiguous()
        x_out1 = self.cw(x_perm1)
        x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
        # Branch 2: swap C and W -> (B, W, H, C); attention over (H, C)
        x_perm2 = x.permute(0, 3, 2, 1).contiguous()
        x_out2 = self.hc(x_perm2)
        x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
        if not self.no_spatial:
            # Branch 3: plain spatial attention over (H, W)
            x_out = self.hw(x)
            x_out = 1 / 3 * (x_out + x_out11 + x_out21)
        else:
            x_out = 1 / 2 * (x_out11 + x_out21)
        return x_out
```
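Which dimensions each TripletAttention branch models follows from the permutations alone. AttentionGate Z-pools over dim 1 (max plus mean) and builds a 7×7 convolutional attention map over the remaining two dims. The pure-Python sketch below traces this on a shape tuple (B, C, H, W). It also confirms that both permutations are their own inverses, which is why forward can reuse the same `permute` call to rotate the result back. The function names are hypothetical.

```python
# Hypothetical helpers: trace the three branches on shape tuples.

def permute(shape, dims):
    return tuple(shape[d] for d in dims)


def triplet_branches(shape):
    names = ("B", "C", "H", "W")
    branches = {"cw": (0, 2, 1, 3), "hc": (0, 3, 2, 1), "hw": (0, 1, 2, 3)}
    out = {}
    for key, dims in branches.items():
        permuted = permute(names, dims)
        out[key] = {
            "permuted_shape": permute(shape, dims),
            "pooled_over": permuted[1],      # ZPool reduces dim 1
            "attention_over": permuted[2:],  # 2-D map seen by the 7x7 conv
        }
    return out


if __name__ == "__main__":
    for key, info in triplet_branches((2, 64, 32, 16)).items():
        print(key, info)
```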
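13、ECA (Efficient Channel Attention)🌱🌱
Paper: https://arxiv.org/pdf/1910.03151
In convolutional neural networks, channel attention improves performance by emphasizing the informative feature channels. The classic SE module is effective, but its fully connected layers add parameters, computation and memory. ECA is a lighter alternative that keeps the benefits of channel attention while being more efficient. Instead of fully connected layers, it models local cross-channel interaction with a single 1D convolution applied to the globally average-pooled channel descriptor. This needs only a handful of parameters and greatly reduces computational cost.
Advantages:
- Low computational cost: ECA avoids fully connected layers and uses only a lightweight 1D convolution, greatly reducing computation and memory usage.
- Very few parameters: the only learnable weights belong to the 1D convolution (kernel_size weights, plus one bias in the implementation below). This keeps the model compact and suits embedded and mobile devices with limited compute.
- Efficient modeling of channel dependencies: through local cross-channel interaction, ECA captures the important dependencies between channels and improves model performance.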
```python
import torch
from torch import nn
from torch.nn import init


class ECAAttention(nn.Module):
    def __init__(self, kernel_size=3):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        # 1D conv across the channel axis; use an odd kernel_size so the length is preserved
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        # Generic initializer; not called automatically
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Conv2d)):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        y = self.gap(x)  # bs,c,1,1
        y = y.squeeze(-1).permute(0, 2, 1)  # bs,1,c
        y = self.conv(y)  # bs,1,c
        y = self.sigmoid(y)  # bs,1,c
        y = y.permute(0, 2, 1).unsqueeze(-1)  # bs,c,1,1
        return x * y.expand_as(x)
```
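The efficiency argument can be made concrete by counting the parameters of the channel-attention part alone. The pure-Python sketch below compares an SE block with two fully connected layers C → C/r → C against the single `nn.Conv1d(1, 1, kernel_size)` of ECAAttention above, which uses PyTorch's default `bias=True`. The SE layers are assumed to be bias-free with r = 16; that is an assumption for illustration, not a statement about any particular SE implementation. The function names are hypothetical.

```python
# Hypothetical helpers: parameter counts of the attention part only.

def se_params(channels, reduction=16, bias=False):
    hidden = channels // reduction
    n = channels * hidden + hidden * channels  # FC C -> C/r, then C/r -> C
    if bias:
        n += hidden + channels
    return n


def eca_params(kernel_size=3, bias=True):
    # Conv1d(1, 1, kernel_size): independent of the number of channels
    return kernel_size + (1 if bias else 0)


if __name__ == "__main__":
    for c in (256, 512, 1024):
        print(c, se_params(c), eca_params(3))
```

The SE count grows quadratically with the channel count, while the ECA count stays at kernel_size + 1 no matter how wide the layer is.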
14、ParNetAttention🌱🌱
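Paper: https://arxiv.org/abs/2110.07641
ParNetAttention is an attention mechanism designed for non-deep networks. It aims to improve performance while keeping computation efficient and the structure simple.
Advantages:
- Lightweight: compared with attention mechanisms built for deep networks, ParNetAttention is cheap to compute, which makes it a good fit for non-deep networks.
- Stronger feature expression: a squeeze-excitation style branch (global average pooling, 1×1 convolution, sigmoid) computes a weight for each channel. The reweighted input is summed with parallel 1×1 and 3×3 convolution branches, which highlights the key features.
- Simple structure: ParNetAttention has a simple design and is easy to integrate into existing non-deep networks.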
```python
from torch import nn


class ParNetAttention(nn.Module):
    def __init__(self, channel=512):
        super().__init__()
        # Squeeze-excitation style branch: per-channel weights from global context
        self.sse = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channel, channel, kernel_size=1),
            nn.Sigmoid()
        )
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=1),
            nn.BatchNorm2d(channel)
        )
        self.conv3x3 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(channel)
        )
        self.silu = nn.SiLU()

    def forward(self, x):
        x1 = self.conv1x1(x)
        x2 = self.conv3x3(x)
        x3 = self.sse(x) * x  # channel-reweighted input
        y = self.silu(x1 + x2 + x3)
        return y
```
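The `sse` branch of ParNetAttention is a channel gate. Each channel is average-pooled to a scalar, a 1×1 convolution (equivalent to a C×C matrix plus a bias on these scalars) mixes the channels, a sigmoid maps the result to (0, 1), and every channel of x is scaled by its own gate. The pure-Python sketch below walks through that branch for one image. The weights and inputs are illustrative values, not trained parameters, and the function name is hypothetical. The full module then adds the 1×1 and 3×3 convolution branches and applies SiLU.

```python
import math

# Hypothetical helper: the sse branch for one image, channels as nested lists.

def sse_gate(x, weight, bias):
    """x: list of C channels, each an H x W list of lists; weight: C x C; bias: C."""
    pooled = [sum(map(sum, ch)) / (len(ch) * len(ch[0])) for ch in x]  # global avg pool
    gates = []
    for c in range(len(x)):
        z = sum(weight[c][j] * pooled[j] for j in range(len(x))) + bias[c]  # 1x1 conv
        gates.append(1 / (1 + math.exp(-z)))                                # sigmoid
    scaled = [[[g * v for v in row] for row in ch] for g, ch in zip(gates, x)]
    return gates, scaled


if __name__ == "__main__":
    x = [[[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [2.0, 2.0]]]
    gates, _ = sse_gate(x, [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])
    print(gates)
```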
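15、MHSA (Multi-Head-Self-Attention)🌱🌱
Paper: https://wuch15.github.io/paper/EMNLP2019-NRMS.pdf
Multi-Head Self-Attention uses several attention heads to compute, in parallel, the correlations between different positions of the input, capturing rich contextual information. Each head can independently focus on different features or positions, and the information from all heads is then fused into a stronger feature representation.
Concretely, the input is processed by several independent attention mechanisms, each called a "head". The heads' outputs are concatenated and, in the general formulation, linearly transformed to produce the final features. The implementation below concatenates the heads by reshaping and omits that final projection.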
```python
import torch
import torch.nn as nn


class MHSA(nn.Module):
    def __init__(self, n_dims, width=14, height=14, heads=4, pos_emb=False):
        super(MHSA, self).__init__()
        self.heads = heads
        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.pos = pos_emb
        if self.pos:
            # Relative position embeddings: width must equal the size of input dim 2
            # and height the size of input dim 3
            self.rel_h_weight = nn.Parameter(torch.randn([1, heads, n_dims // heads, 1, int(height)]),
                                             requires_grad=True)
            self.rel_w_weight = nn.Parameter(torch.randn([1, heads, n_dims // heads, int(width), 1]),
                                             requires_grad=True)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n_batch, C, width, height = x.size()
        q = self.query(x).view(n_batch, self.heads, C // self.heads, -1)  # B,heads,d,N (N = positions)
        k = self.key(x).view(n_batch, self.heads, C // self.heads, -1)
        v = self.value(x).view(n_batch, self.heads, C // self.heads, -1)
        content_content = torch.matmul(q.permute(0, 1, 3, 2), k)  # B,heads,N,N
        c1, c2, c3, c4 = content_content.size()
        if self.pos:
            content_position = (self.rel_h_weight + self.rel_w_weight).view(1, self.heads, C // self.heads, -1).permute(
                0, 1, 3, 2)  # 1,heads,N,d
            content_position = torch.matmul(content_position, q)  # B,heads,N,N
            content_position = content_position if (
                content_content.shape == content_position.shape) else content_position[:, :, :c3, ]
            assert (content_content.shape == content_position.shape)
            energy = content_content + content_position
        else:
            energy = content_content
        attention = self.softmax(energy)
        out = torch.matmul(v, attention.permute(0, 1, 3, 2))  # B,heads,d,N
        out = out.view(n_batch, C, width, height)  # concatenate heads back into C channels
        return out
```
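The score matrix `content_content` in MHSA.forward has shape (B, heads, N, N) with N = H×W, so its size grows with the square of the feature-map area. The pure-Python sketch below estimates that size for the feature maps a 640×640 input produces at strides 8, 16 and 32 (80×80, 40×40, 20×20). It assumes 4 heads, batch size 1 and fp32, and counts only this one matrix. The function names are hypothetical.

```python
# Hypothetical helpers: size of the N x N attention score matrix in MHSA.

def attention_entries(h, w, heads=4, batch=1):
    n = h * w                        # number of positions
    return batch * heads * n * n     # one N x N matrix per head


def attention_mib(h, w, heads=4, batch=1, bytes_per_el=4):
    return attention_entries(h, w, heads, batch) * bytes_per_el / 2**20


if __name__ == "__main__":
    for stride in (8, 16, 32):
        s = 640 // stride
        print(f"{s}x{s}: {attention_entries(s, s):,} entries, "
              f"{attention_mib(s, s):.1f} MiB per image")
```

At 80×80 this single matrix already takes 625 MiB per image, versus about 2.4 MiB at 20×20. This suggests placing global self-attention like this on the lowest-resolution feature maps. With `pos_emb=True`, the `width` and `height` arguments must also match the chosen feature map.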
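Note: see the previous article for how to add these attention modules. The insertion position is not fixed and can be adjusted flexibly to suit your needs. Thanks for your support and attention ❤❤
That's all for this article. The series is continuously updated, so stay tuned!!!
If you liked this article, please like and bookmark it. Thanks for your support 🍵🍵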