YOLOV8修改网络结构的方法(增加坐标注意力模块)

CA(CoordAttention)主要是告诉模型更应该关注哪些位置的信息,因此经常被用于优化网络,现尝试将CA模块加入YOLOv8的网络中。

1、在ultralytics/cfg/models/v8/yolov8.yaml的基础上修改网络结构,复制yolov8.yaml并重命名为CAyolov8.yaml

2、添加CA层至backbone中,记得修改之后网络层的编号。

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  ################## CA ################
  - [ -1, 1, CA, [ 1024, 32 ] ] # 9 CA block

  - [-1, 1, SPPF, [1024, 5]]  # 10

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 11
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

3、打开ultralytics/nn/modules/block.py,添加CA模块。

class CA(nn.Module):
    def __init__(self, c1, c2, reduction):
        super(CA, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, c1 // reduction)

        self.conv1 = nn.Conv2d(c1, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(mip, c2, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, c2, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x

        n, c, h, w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)

        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        out = identity * a_w * a_h

        return out

4、在ultralytics/nn/modules/__init__.py、ultralytics/nn/tasks.py中importCA模块,另外,在task.py中对parse_model的部分内容进行修改:

     #添加了CA 
  if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
                 BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3, CA):
            c1, c2 = ch[f], args[0]

5、调试网络能否跑通。修改task.py中的DetectionModel,将默认配置文件改为"CAyolov8.yaml":

def __init__(self, cfg='CAyolov8.yaml', ch=3, nc=None, verbose=True):  # model, input channels, number of classes

此处直接输入配置文件的文件名即可,其相对位置应为ultralytics/cfg/models/v8/CAyolov8.yaml,但是v8在读取配置文件的时候应该是直接从v8文件夹里读取的。之后将下列代码添加到task.py最后,直接运行task.py。

if __name__ == '__main__':
    net = DetectionModel()
    print(net)
    net = net.cuda()

    input_rgb = torch.Tensor(8, 3, 640, 640).cuda()

    output = net(input_rgb)
    print(len(output))
    print(output[0].shape)
    print(output[1].shape)
    print(output[2].shape)

运行结果应该为打印出来的网络结构和以下输出:

3
torch.Size([8, 144, 80, 80])
torch.Size([8, 144, 40, 40])
torch.Size([8, 144, 20, 20])

之后按正常训练流程写dataset配置文件进行训练测试即可。

另外,如果训练过程中报警告:UserWarning: adaptive_avg_pool2d_backward_cuda does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True, warn_only=True)'. 这说明我们使用了带有不确定性的算法,就比如我们加入的CA模块,但是v8启用了确定性算法模式,弹出相应的警告。

解决方法:找到ultralytics/engine/trainer.py文件,修改backward:

# Backward
torch.use_deterministic_algorithms(False) # 禁用确定性算法模式
self.scaler.scale(self.loss).backward()

在计算backward时将确定性算法关闭即可。

  • 6
    点赞
  • 63
    收藏
    觉得还不错? 一键收藏
  • 14
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值