YOLOv8 with Attention 注意力机制

本文来源于:YOLOv8-AM: YOLOv8 with Attention Mechanisms for Pediatric Wrist Fracture Detection
代码:github

总的结构图,可以看到注意力机制模块被加载在neck部分,在upsample、C2f之后。
在这里插入图片描述
相比yolov8的模型配置文件,根据以上结构图,在每次upsanple、C2f模块后,添加了注意力机制。
其中[-1, 1, ShuffleAttention, [512]]参数含义为 :
“-1” :使用前一层的输出作为输入;
“1” :重复一次;
“ShuffleAttention”:注意力机制模块;
“[512]” :注意力机制模块的参数,和前一层的输出通道数一致,也可以有多个参数,根据模块要求配置。
最后,修改相应的detect head的输入的层编号为[17, 21, 25]。
在这里插入图片描述
下面需要修改相应的代码,让模型能加载注意力机制模块。
在这里插入图片描述
1.如果把注意力机制模块的代码放在了“ultralytics/nn/modules/conv.py” 中,那么就要修改__init__.py ,import 相应的模块。
在这里插入图片描述
在解析yaml配置文件,构建模型的时候,需要修改“ultralytics/nn/tasks.py” ,在开头import 相应模块。
在这里插入图片描述
然后修改parse_model函数,
在这里插入图片描述
构建模型后,可以看到已经添加了ShuffleAttention模块。
在这里插入图片描述
根据作者的实验结果,ResBlock_CBAM取得了较好的效果,并且推理时间只增加了1ms。
在这里插入图片描述

class ResBlock_CBAM(nn.Module):
    def __init__(self, in_places, places, stride=1, downsampling=False, expansion=1):
        super(ResBlock_CBAM, self).__init__()
        self.expansion = expansion
        self.downsampling = downsampling

        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels=in_places, out_channels=places, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(places),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(places),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(in_channels=places, out_channels=places * self.expansion, kernel_size=1, stride=1,
                      bias=False),
            nn.BatchNorm2d(places * self.expansion),
        )
        # self.cbam = CBAM(c1=places * self.expansion, c2=places * self.expansion, )
        self.cbam = CBAM(c1=places * self.expansion)

        if self.downsampling:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels=in_places, out_channels=places * self.expansion, kernel_size=1, stride=stride,
                          bias=False),
                nn.BatchNorm2d(places * self.expansion)
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.bottleneck(x)
        out = self.cbam(out)
        if self.downsampling:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out

更多的注意力机制模块可以参考CNN中的注意力机制

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值