前言
注意力机制是非常重要的改进手段,yolov8是目标检测目前最先进的技术,本文将详细展示如何在yolov8中添加注意力机制。
2. 如何改进
2.1 注意力机制代码
本文以ShuffleAttention为例:
在ultralytics\nn下创建ShuffleAttention.py,以下是代码。
import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter
class ShuffleAttention(nn.Module):
def __init__(self, channel=512, reduction=16, G=8):
super().__init__()
self.G = G
self.channel = channel
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
self.sigmoid = nn.Sigmoid()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)
@staticmethod
def channel_shuffle(x, groups):
b, c, h, w = x.shape
x = x.reshape(b, groups, -1, h, w)
x = x.permute(0, 2, 1, 3, 4)
# flatten
x = x.reshape(b, -1, h, w)
return x
def forward(self, x):
b, c, h, w = x.size()
# group into subfeatures
x = x.view(b * self.G, -1, h, w) # bs*G,c//G,h,w
# channel_split
x_0, x_1 = x.chunk(2, dim=1) # bs*G,c//(2*G),h,w
# channel attention
x_channel = self.avg_pool(x_0) # bs*G,c//(2*G),1,1
x_channel = self.cweight * x_channel + self.cbias # bs*G,c//(2*G),1,1
x_channel = x_0 * self.sigmoid(x_channel)
# spatial attention
x_spatial = self.gn(x_1) # bs*G,c//(2*G),h,w
x_spatial = self.sweight * x_spatial + self.sbias # bs*G,c//(2*G),h,w
x_spatial = x_1 * self.sigmoid(x_spatial) # bs*G,c//(2*G),h,w
# concatenate along channel axis
out = torch.cat([x_channel, x_spatial], dim=1) # bs*G,c//G,h,w
out = out.contiguous().view(b, -1, h, w)
# channel shuffle
out = self.channel_shuffle(out, 2)
return out
2.2 修改task.py
打开task.py
导入刚刚添加的包:
点击Ctrl+f搜索elif m in到elif m is nn.BatchNorm2d: args = [ch[f]]代码下
代码如下:
elif m in {ShuffleAttention}: # 在此处添加自己的注意力机制,命名是刚刚的类名
args = [ch[f], *args]
2.3 修改模型
找到模型的位置ultralytics/models/v8/yolov8n.yaml,将此模型拷贝一次,在模型中添加自己的注意力机制,如下代码所示。
如果用的yolov8s,则修改yolov8s.yaml
- [ -1, 1, ShuffleAttention, [ 16, 8 ] ] #16
- [ -1, 1, Conv, [ 256, 3, 2 ] ]
- [ [ -1, 12 ], 1, Concat, [ 1 ] ] # cat head P4
- [ -1, 3, C2f, [ 512 ] ] # 19 (P4/16-medium)
- [ -1, 1, ShuffleAttention, [ 16, 8 ] ] #20
- [ -1, 1, Conv, [ 512, 3, 2 ] ]
- [ [ -1, 9 ], 1, Concat, [ 1 ] ] # cat head P5 22
- [ -1, 3, C2f, [ 1024 ] ] # 23 (P5/32-large)
注意力机制源于对人类视觉的研究。 在认知科学中,由于信息处理的瓶颈,人类会选择性地关注所有信息的一部分,同时忽略其他可见的信息。 上述机制通常被称为注意力机制。 人类视网膜不同的部位具有不同程度的信息处理能力,即敏锐度,只有视网膜中央凹部位具有最强的敏锐度。
简而言之,注意力机制源于自然界人类视觉的研究。人类的视觉会天然地进行一个抉择,就是选择性地关注所有信息的一个部分,同事就会忽略其他可见的信息。就属于是合理的利用有限的信息处理资源。