YOLOv5添加改进GAM全局注意力机制,添加软阈值化记录

文章介绍了如何在YOLOv5对象检测框架的common.py文件中集成名为SGAMAttention的自定义模块,该模块包含通道注意力和空间注意力两部分。在配置文件中,详细说明了如何配置SGAMAttention层的参数,并在backbone部分的模型配置中添加该模块,以提升模型的性能。
摘要由CSDN通过智能技术生成

 在YOLOv5项目中的common.py中添加以下SGAMAttention模块。

# ---------------------------GAMAttention Begin---------------------------
class SGAMAttention(nn.Module):

    def __init__(self, c1, c2, group=True, rate=4):
        super(SGAMAttention, self).__init__()

        self.channel_attention = nn.Sequential(
            nn.Linear(c1, int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(c1 / rate), c1)
        )
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(c1, int(c1 / rate),
                                                                                                     kernel_size=7,
                                                                                                     padding=3),
            nn.BatchNorm2d(int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(int(c1 / rate), c2,
                                                                                                     kernel_size=7,
                                                                                                     padding=3),
            nn.BatchNorm2d(c2)
        )
        self.shrinkage = Shrinkage(c2, gap_size=(1, 1))

    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)  # b,c,h,w
        x = x * x_channel_att
        # c = torch.randn[1,128,32,32]
        # x = self.model(x)
        # x = x.view(128, 1, 1)
        x = self.shrinkage(x)

        # print(x.size())

        x_spatial_att = self.spatial_attention(x).sigmoid()
        x_spatial_att = channel_shuffle(x_spatial_att, 4)  # last shuffle
        out = x * x_spatial_att
        return out


class Shrinkage(nn.Module):
    def __init__(self, channel, gap_size):
        super(Shrinkage, self).__init__()
        self.gap = nn.AdaptiveAvgPool2d(gap_size)
        # self.fc = nn.Sequential(
        #     nn.Linear(channel, channel),
        #     nn.BatchNorm1d(channel),
        #     nn.ReLU(inplace=True),
        #     nn.Linear(channel, channel),
        #     nn.Sigmoid(),
        # )
        self.fc1 = nn.Sequential(
            nn.Linear(channel, channel),
        )

        self.BN = nn.BatchNorm1d(channel)
        self.fc2 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Linear(channel, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x_raw = x
        x = torch.abs(x)
        x_abs = x
        x = self.gap(x)
        # print(x.size())
        x = torch.flatten(x, 1)

        # average = torch.mean(x, dim=1, keepdim=True)
        average = x
        # print(x.size())
        x = self.fc1(x)
        try:
            x = self.BN(x)
        except ValueError as e:
            pass

        x = self.fc2(x)
        x = torch.mul(average, x)
        x = x.unsqueeze(2).unsqueeze(2)
        # soft thresholding
        sub = x_abs - x
        zeros = sub - sub
        n_sub = torch.max(sub, zeros)
        x = torch.mul(torch.sign(x_raw), n_sub)
        return x



# ---------------------------SGAMAttention End---------------------------

 配置文件写法:

# 1024换成上一层的通道数
[-1, 1, GAMAttention, [1024, True, 4]],  

在yolo.py中添加SGAMAttention

 elif m in [SGAMAttention]:
            c1, c2 = ch[f], args[0]
            if c2 != no:
                c2 = make_divisible(c2 * gw, 8)
            args = [c1, c2, *args[1:]]

示例,将其加在backbone里的模型配置文件

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 4  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, SGAMAttention, [256, True, 4]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, SGAMAttention, [512, True, 4]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
   [-1, 1, SGAMAttention, [1024, True, 4]],
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 8], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 17], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 13], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)

   [[20, 23, 26], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

  • 0
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值