YOLOv9改进策略【注意力机制篇】| 添加SE、CBAM、ECA、CA、Swin Transformer等注意力和多头注意力机制


前言

这篇文章带来一个经典注意力模块的汇总,虽然有些模块已经发布很久了,但后续的注意力模块也都是在此基础之上进行改进的,对于初学者来说还是有必要去学习了解一下,以加深对模块,模型的理解。

一、为什么要引入注意力机制?

来源:注意力机制的设计灵感来源于人类视觉系统。当我们在观察外界事物时,会自动将注意力集中在重要或感兴趣的区域,而忽略无关信息。计算机视觉中的注意力机制就是在试图模拟这一过程,以提高模型的感知和理解能力。

问题:随着图像数据量的增加,模型需要处理的信息量也随之增大。传统的卷积神经网络在处理大量数据时可能会遇到信息过载的问题,导致性能下降。注意力机制通过有选择地关注重要信息,帮助模型在海量数据中筛选出关键内容,从而提高检测精度。

好处

  • 注意力机制能够赋予输入数据的不同部分以不同的权重,使模型更加关注重要的特征信息。
  • 通过生成热力图,显示模型在做出决策时关注的具体区域,有助于更好地理解模型的决策过程,增强模型的可解释性。
  • 注意力机制使模型在处理不同数据集和任务时能够更灵活地调整其关注点。有助于提升模型的泛化能力,使其在面对新数据集或新任务时仍能保持较高的性能水平。

除了能够提升性能外,其最主要的还是其即插即用的特性,无论模块放在什么地方,都可以运行查看训练效果,更方便炼丹成功~

二、SE

2.1 SE的原理

通道注意力模块关注于网络中每个通道的重要性,通过为每个通道分配不同的权重,使得网络能够更加关注那些对任务更为关键的通道特征,从而提高模型的性能。其中主要涉及SqueezeExcitation两个操作。

  • Squeeze操作:通过全局平均池化将每个通道的特征图压缩为一个实数。
  • Excitation操作:利用两个全连接层(先降维后升维)和一个ReLU激活函数来学习通道间的依赖关系,并通过sigmoid函数生成权重向量。
  • Scale操作:将学习到的通道权重与原始特征图进行逐通道相乘,实现特征的重标定。
    在这里插入图片描述

论文:https://arxiv.org/abs/1709.01507
源码:https://github.com/hujie-frank/SENet

2.2 SE的实现代码

import torch.nn as nn


# SE
class SE(nn.Module):
    def __init__(self, c1, ratio=16):
        super(SE, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.l1 = nn.Linear(c1, c1 // ratio, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.l2 = nn.Linear(c1 // ratio, c1, bias=False)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avgpool(x).view(b, c)
        y = self.l1(y)
        y = self.relu(y)
        y = self.l2(y)
        y = self.sig(y)
        y = y.view(b, c, 1, 1)
        return x * y.expand_as(x)

注意❗:在7.2小节中的yolo.py文件中需要声明的模块名称为:SE


三、CBAM

3.1 CBAM的原理

CBAM注意力模块通道注意力模块空间注意力模块两部分组成。它通过顺序地应用通道注意力和空间注意力,使得网络能够自适应地关注到输入特征图中最重要的通道和空间位置,从而提高模型的表征能力

在这里插入图片描述

  • 通道注意力模块(CAM)
    此部分的操作步骤与SE通道注意力模块的步骤一致。

  • 空间注意力模块(SAM)

    • 特征提取:在通道注意力模块处理后的特征图上,分别进行基于通道维度的最大池化平均池化操作,以生成两个新的特征图。
    • 特征融合:将两个池化后的特征图在通道维度上进行拼接(concatenate),然后通过一个卷积层进行特征融合,生成空间注意力权重图。
    • 特征增强:将空间注意力权重图与原始特征图进行逐元素相乘,实现特征的增强,使得模型能够关注到更重要的空间位置信息。

在这里插入图片描述

论文:https://arxiv.org/abs/1807.06521
源码:https://github.com/Jongchan/attention-module

3.2 CBAM的实现代码

import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return out
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        # 1*h*w
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        # 2*h*w
        x = self.conv(x)
        # 1*h*w
        return self.sigmoid(x)


class CBAM(nn.Module):
    def __init__(self, c1, c2, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(c1, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        out = self.channel_attention(x) * x
        # c*h*w
        # c*h*w * 1*h*w
        out = self.spatial_attention(out) * out
        return out

注意❗:在7.2小节中的yolo.py文件中需要声明的模块名称为:CBAM


四、ECA

4.1 ECA的原理

ECA注意力模块的核心思想是在不增加过多计算成本和参数的情况下,通过引入一种有效的通道注意力机制,来增强网络对关键特征的关注能力。它避免了通道注意力机制中可能存在的降维操作带来的性能损失,通过一种自适应的跨通道交互策略来实现通道权重的生成。步骤如下:

  • 特征压缩:这里和SE注意力中的Squeeze操作一致,省略啦。
  • 特征学习ECA使用一维卷积来代替SE注意力中的全连接层,来学习通道间的依赖关系。这里的一维卷积核大小是自适应的,与通道维度成正比,以确保不同通道数的特征图都能有效地进行跨通道交互。通过一维卷积,ECA能够直接捕获局部跨通道交互信息,而无需进行复杂的降维和升维操作。
  • 特征重标定:这里也和SE注意力中的Scale操作一致。

在这里插入图片描述

论文:https://arxiv.org/abs/1910.03151
源码:https://github.com/BangguWu/ECANet

4.2 ECA的实现代码

import torch
import torch.nn as nn


class ECA(nn.Module):
    def __init__(self, c1, c2, k_size=3):
        super(ECA, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)

注意❗:在7.2小节中的yolo.py文件中需要声明的模块名称为:ECA


五、CA

5.1 CA的原理

CA注意力模块的核心思想是将位置信息嵌入到通道注意力中,以更精确地捕捉到图像中的空间分布特征。与普通的通道注意力机制不同的是,CA不仅关注通道间的依赖关系,还通过引入坐标信息来增强模型对空间细节的敏感度。步骤如下:

  • 特征分解:输入特征图通常具有C(通道数)、H(高度)、W(宽度)三个维度。CA首先通过两个并行的全局平均池化操作,分别沿垂直(高度)和水平(宽度)方向聚合输入特征,生成两个包含方向特定信息的特征图。这两个特征图分别捕捉了高度和宽度方向上的空间信息。
  • 特征编码:将两个池化后的特征图在通道维度上拼接,并通过一个1x1的二维卷积层来融合和转换特征。然后,对卷积后的特征图进行批量归一化非线性激活
  • 注意力图生成:将批量归一化和激活后的特征图分裂为两个特征图,分别对应于高度和宽度方向。接着,通过另外两个1x1的二维卷积层分别处理这两个特征图,并应用Sigmoid激活函数,生成两个注意力图。这两个注意力图分别沿宽度和高度方向对输入特征图进行重标定。
  • 特征重加权:将通过Sigmoid激活的注意力图与原始的输入特征图相乘,以重新加权原始特征。这样,重要的特征会被放大,而不重要的特征则会减弱,从而增强模型对关键信息的关注能力。

论文:https://arxiv.org/abs/2103.02907v1
源码:https://github.com/houqb/CoordAttention

5.2 CA的实现代码

import torch
import torch.nn as nn


class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        mip = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x
        n, c, h, w = x.size()
        # c*1*W
        x_h = self.pool_h(x)
        # c*H*1
        # C*1*h
        x_w = self.pool_w(x).permute(0, 1, 3, 2)
        y = torch.cat([x_h, x_w], dim=2)
        # C*1*(h+w)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()
        out = identity * a_w * a_h
        return out
        

注意❗:在7.2小节中的yolo.py文件中需要声明的模块名称为:CoordAtt


六、Swin Transformer

6.1 Swin Transformer的原理

Swin Transformer通过分层设计结合多个等级的窗口划分来降低计算复杂度,并提出位移窗口使相邻的窗口之间进行交互,从而达到全局建模的能力。在Swin Transformer模型中最重要的是模块是窗口多头自注意力(W-MSA)移动窗口多头自注意力(SW-MSA),用于自注意力的计算。

  • 窗口多头自注意力(W-MSA)
    • 划分窗口:将特征图划分为多个固定大小的窗口。
    • 自注意力计算:在每个窗口内独立计算多头自注意力,此时计算复杂度与窗口内的小块数量成线性关系,从而降低了整体计算复杂度。
    • 输出:得到每个窗口内的自注意力特征图。
  • 移动窗口多头自注意力(SW-MSA)
    • 窗口移动:在W-MSA之后,通过移动窗口的方式改变窗口的划分,使得相邻的窗口之间能够产生交互。
    • 自注意力计算:在新的窗口划分下再次计算多头自注意力。
    • 输出:得到移动窗口后的自注意力特征图。

在这里插入图片描述

论文:https://arxiv.org/abs/2103.14030
源码:https://github.com/microsoft/Swin-Transformer

6.2 Swin Transformer的实现代码

class WindowAttention(nn.Module):

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):

        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # print(attn.dtype, v.dtype)
        try:
            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        except:
            #print(attn.dtype, v.dtype)
            x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

def window_partition(x, window_size):

    B, H, W, C = x.shape
    assert H % window_size == 0, 'feature map h and w can not divide by window size'
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class SwinTransformerLayer(nn.Module):

    def __init__(self, dim, num_heads, window_size=8, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.SiLU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        # if min(self.input_resolution) <= self.window_size:
        #     # if window size is larger than input resolution, we don't partition windows
        #     self.shift_size = 0
        #     self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def create_mask(self, H, W):
        # calculate attention mask for SW-MSA
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x):
        # reshape x[b c h w] to x[b l c]
        _, _, H_, W_ = x.shape

        Padding = False
        if min(H_, W_) < self.window_size or H_ % self.window_size!=0 or W_ % self.window_size!=0:
            Padding = True
            # print(f'img_size {min(H_, W_)} is less than (or not divided by) window_size {self.window_size}, Padding.')
            pad_r = (self.window_size - W_ % self.window_size) % self.window_size
            pad_b = (self.window_size - H_ % self.window_size) % self.window_size
            x = F.pad(x, (0, pad_r, 0, pad_b))

        # print('2', x.shape)
        B, C, H, W = x.shape
        L = H * W
        x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)  # b, L, c

        # create mask from init to forward
        if self.shift_size > 0:
            attn_mask = self.create_mask(H, W).to(x.device)
        else:
            attn_mask = None

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W)  # b c h w

        if Padding:
            x = x[:, :, :H_, :W_]  # reverse padding

        return x


class SwinTransformerBlock(nn.Module):
    def __init__(self, c1, c2, num_heads, num_layers, window_size=8):
        super().__init__()
        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)

        # remove input_resolution
        self.blocks = nn.Sequential(*[SwinTransformerLayer(dim=c2, num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2) for i in range(num_layers)])

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        x = self.blocks(x)
        return x


class STCSPA(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(STCSPA, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformerBlock(c_, c_, num_heads, n)
        #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        y1 = self.m(self.cv1(x))
        y2 = self.cv2(x)
        return self.cv3(torch.cat((y1, y2), dim=1))


class STCSPB(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(STCSPB, self).__init__()
        c_ = int(c2)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformerBlock(c_, c_, num_heads, n)
        #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        x1 = self.cv1(x)
        y1 = self.m(x1)
        y2 = self.cv2(x1)
        return self.cv3(torch.cat((y1, y2), dim=1))


class STCSPC(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(STCSPC, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(c_, c_, 1, 1)
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformerBlock(c_, c_, num_heads, n)
        #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        y1 = self.cv3(self.m(self.cv1(x)))
        y2 = self.cv2(x)
        return self.cv4(torch.cat((y1, y2), dim=1))

注意❗:在7.2小节中的yolo.py文件中需要声明的模块名称为:STCSPA, STCSPB, STCSPC,在模型中使用哪个选哪个就行。


七、添加步骤

此处在模型配置中以SE通道注意力为例,列举的其他注意力模块添加步骤与此完全一致

1. 修改common.py

此处需要修改的文件是models/common.py

common.py中定义了网络结构的通用模块,我们想要加入新的模块就只需要将模块代码放到这个文件内即可。

SE添加后如下:

在这里插入图片描述

2. 修改yolo.py

此处需要修改的文件是models/yolo.py

yolo.py用于函数调用,我们只需要将common.py中定义的新的模块命添加到parse_model函数下即可。

SE添加后如下:

在这里插入图片描述


八、yaml模型文件

在代码配置完成后,配置模型的YAML文件。

此处以models/detect/yolov9-c.yaml为例,在同目录下创建一个用于自己数据集训练的模型文件yolov9-c-se.yaml

yolov9-c.yaml中的内容复制到yolov9-c-se.yaml文件下,修改nc数量等于自己数据中目标的数量。
在骨干网络的最后一层添加SE模块,即下方代码中的第45行,只需要填入一个参数,通道数,和前一层通道数一致还需要注意的是,由于PAN+FPN的颈部模型结构存在,层之间的匹配也要记得修改,维度要匹配上

📌 放在此处的目的是让网络能够学习到更深层的语义信息,因为此时特征图尺寸小,包含全局信息。若是希望网络能够更加关注局部 信息,可尝试将注意力模块添加到网络的浅层。

📌 当然由于其即插即用的特性,加在哪里都是可以的,但是想要真的有效,还需要根据模型结构,数据集特性等多方面因素,多做实验进行验证。

# YOLOv9

# parameters
nc: 1  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
#activation: nn.LeakyReLU(0.1)
#activation: nn.ReLU()

# anchors
anchors: 3

# YOLOv9 backbone
backbone:
  [
   [-1, 1, Silence, []],  
   
   # conv down
   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 2-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 3

   # avg-conv down
   [-1, 1, ADown, [256]],  # 4-P3/8

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 5

   # avg-conv down
   [-1, 1, ADown, [512]],  # 6-P4/16

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 7

   # avg-conv down
   [-1, 1, ADown, [512]],  # 8-P5/32

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 9

   [-1, 1, SE, [512]],  # 10  # 注意力添加在此处
  ]

# YOLOv9 head
head:
  [
   # elan-spp block
   [-1, 1, SPPELAN, [512, 256]],  # 10

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 7], 1, Concat, [1]],  # cat backbone P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 13

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 5], 1, Concat, [1]],  # cat backbone P3

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 16 (P3/8-small)

   # avg-conv-down merge
   [-1, 1, ADown, [256]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 19 (P4/16-medium)

   # avg-conv-down merge
   [-1, 1, ADown, [512]],
   [[-1, 11], 1, Concat, [1]],  # cat head P5

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 22 (P5/32-large)
   
   
   # multi-level reversible auxiliary branch
   
   # routing
   [5, 1, CBLinear, [[256]]], # 23
   [7, 1, CBLinear, [[256, 512]]], # 24
   [9, 1, CBLinear, [[256, 512, 512]]], # 25
   
   # conv down
   [0, 1, Conv, [64, 3, 2]],  # 26-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 27-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 28

   # avg-conv down fuse
   [-1, 1, ADown, [256]],  # 29-P3/8
   [[24, 25, 26, -1], 1, CBFuse, [[0, 0, 0]]], # 30  

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 31

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 32-P4/16
   [[25, 26, -1], 1, CBFuse, [[1, 1]]], # 33 

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 34

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 35-P5/32
   [[26, -1], 1, CBFuse, [[2]]], # 36

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 37
   
   
   
   # detection head

   # detect
   [[32, 35, 38, 17, 20, 23], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)
  ]

九、成功运行结果

打印网络模型可以看到SE模块已经加入到模型中,并可以进行训练了。

其他模块如:CBAM、ECA、CA、Swin Transformer这些模块和SE的添加步骤完全一致。

并且对于这里未提到的注意力模块的添加步骤也是一样的,只要加入模块代码,并将其添加到模型中即可。

	                 from  n    params  module                                  arguments                     
	  0                -1  1         0  models.common.Silence                   []                            
	  1                -1  1      1856  models.common.Conv                      [3, 64, 3, 2]                 
	  2                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]               
	  3                -1  1    212864  models.common.RepNCSPELAN4              [128, 256, 128, 64, 1]        
	  4                -1  1    164352  models.common.ADown                     [256, 256]                    
	  5                -1  1    847616  models.common.RepNCSPELAN4              [256, 512, 256, 128, 1]       
	  6                -1  1    656384  models.common.ADown                     [512, 512]                    
	  7                -1  1   2857472  models.common.RepNCSPELAN4              [512, 512, 512, 256, 1]       
	  8                -1  1    656384  models.common.ADown                     [512, 512]                    
	  9                -1  1   2857472  models.common.RepNCSPELAN4              [512, 512, 512, 256, 1]       
	 10                -1  1      1024  models.common.SE                        [512, 512]                    
	 11                -1  1    656896  models.common.SPPELAN                   [512, 512, 256]               
	 12                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
	 13           [-1, 7]  1         0  models.common.Concat                    [1]                           
	 14                -1  1   3119616  models.common.RepNCSPELAN4              [1024, 512, 512, 256, 1]      
	 15                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
	 16           [-1, 5]  1         0  models.common.Concat                    [1]                           
	 17                -1  1    912640  models.common.RepNCSPELAN4              [1024, 256, 256, 128, 1]      
	 18                -1  1    164352  models.common.ADown                     [256, 256]                    
	 19          [-1, 14]  1         0  models.common.Concat                    [1]                           
	 20                -1  1   2988544  models.common.RepNCSPELAN4              [768, 512, 512, 256, 1]       
	 21                -1  1    656384  models.common.ADown                     [512, 512]                    
	 22          [-1, 11]  1         0  models.common.Concat                    [1]                           
	 23                -1  1   3119616  models.common.RepNCSPELAN4              [1024, 512, 512, 256, 1]      
	 24                 5  1    131328  models.common.CBLinear                  [512, [256]]                  
	 25                 7  1    393984  models.common.CBLinear                  [512, [256, 512]]             
	 26                 9  1    656640  models.common.CBLinear                  [512, [256, 512, 512]]        
	 27                 0  1      1856  models.common.Conv                      [3, 64, 3, 2]                 
	 28                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]               
	 29                -1  1    212864  models.common.RepNCSPELAN4              [128, 256, 128, 64, 1]        
	 30                -1  1    164352  models.common.ADown                     [256, 256]                    
	 31  [24, 25, 26, -1]  1         0  models.common.CBFuse                    [[0, 0, 0]]                   
	 32                -1  1    847616  models.common.RepNCSPELAN4              [256, 512, 256, 128, 1]       
	 33                -1  1    656384  models.common.ADown                     [512, 512]                    
	 34      [25, 26, -1]  1         0  models.common.CBFuse                    [[1, 1]]                      
	 35                -1  1   2857472  models.common.RepNCSPELAN4              [512, 512, 512, 256, 1]       
	 36                -1  1    656384  models.common.ADown                     [512, 512]                    
	 37          [26, -1]  1         0  models.common.CBFuse                    [[2]]                         
	 38                -1  1   2857472  models.common.RepNCSPELAN4              [512, 512, 512, 256, 1]       
	 39[32, 35, 38, 17, 20, 23]  1  21542822  models.yolo.DualDDetect                 [1, [512, 512, 512, 256, 512, 512]]
  • 20
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Limiiiing

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值