在YOLOv5项目中的common.py中添加以下SGAMAttention模块。
# ---------------------------SGAMAttention Begin---------------------------
class SGAMAttention(nn.Module):
    """GAM-style attention block with an added Shrinkage (soft-thresholding) stage.

    Channel attention (MLP over the channel axis) is applied first, then the
    Shrinkage denoising stage, then spatial attention (two 7x7 convolutions)
    followed by a channel shuffle.

    Args:
        c1: input channel count.
        c2: output channel count of the spatial-attention branch.
        group: if True, use grouped 7x7 convolutions (groups=rate).
        rate: channel reduction ratio for both attention branches.
    """

    def __init__(self, c1, c2, group=True, rate=4):
        super().__init__()
        hidden = int(c1 / rate)

        # MLP acting on the channel axis: c1 -> c1/rate -> c1.
        self.channel_attention = nn.Sequential(
            nn.Linear(c1, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, c1),
        )

        # Two 7x7 convs with a bottleneck of c1/rate channels; grouped when
        # `group` is set. Layers are appended in the same order as the
        # original definition so parameter registration is unchanged.
        layers = []
        if group:
            layers.append(nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate))
        else:
            layers.append(nn.Conv2d(c1, hidden, kernel_size=7, padding=3))
        layers.append(nn.BatchNorm2d(hidden))
        layers.append(nn.ReLU(inplace=True))
        if group:
            layers.append(nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate))
        else:
            layers.append(nn.Conv2d(hidden, c2, kernel_size=7, padding=3))
        layers.append(nn.BatchNorm2d(c2))
        self.spatial_attention = nn.Sequential(*layers)

        # Soft-thresholding stage (defined elsewhere in this file).
        self.shrinkage = Shrinkage(c2, gap_size=(1, 1))

    def forward(self, x):
        """Apply channel attention, shrinkage, then spatial attention to x (b, c, h, w)."""
        b, c, h, w = x.shape

        # (b, c, h, w) -> (b, h*w, c) so the MLP runs over channels.
        flat = x.permute(0, 2, 3, 1).view(b, -1, c)
        ch_att = self.channel_attention(flat).view(b, h, w, c).permute(0, 3, 1, 2)
        x = x * ch_att

        # Denoise via learned soft thresholding before the spatial branch.
        x = self.shrinkage(x)

        # Spatial gate in (0, 1); channel_shuffle (group=4) mixes the
        # grouped-conv channel blocks.
        sp_att = channel_shuffle(self.spatial_attention(x).sigmoid(), 4)
        return x * sp_att
class Shrinkage(nn.Module):
    """Channel-wise soft-thresholding block (as in Deep Residual Shrinkage Networks).

    A per-channel threshold is learned from the global average of |x| and the
    input is soft-thresholded: sign(x) * max(|x| - tau, 0). Output shape equals
    input shape (b, c, h, w).

    Args:
        channel: number of channels in the input feature map.
        gap_size: output size of the adaptive average pool (typically (1, 1)).
    """

    def __init__(self, channel, gap_size):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(gap_size)
        # fc1 stays wrapped in nn.Sequential so existing checkpoints keep
        # their 'fc1.0.*' parameter keys.
        self.fc1 = nn.Sequential(
            nn.Linear(channel, channel),
        )
        self.BN = nn.BatchNorm1d(channel)
        self.fc2 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Linear(channel, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the soft-thresholded tensor, same shape as x."""
        x_raw = x
        x_abs = torch.abs(x)

        # Per-channel statistic of |x| -> (b, channel).
        average = torch.flatten(self.gap(x_abs), 1)

        scale = self.fc1(average)
        # HACK: BatchNorm1d raises ValueError when training with an effective
        # batch of 1 ("Expected more than 1 value per channel"); the original
        # code skips normalization in that case instead of crashing. Kept for
        # backward compatibility — consider enforcing batch size > 1 upstream.
        try:
            scale = self.BN(scale)
        except ValueError:
            pass
        scale = self.fc2(scale)  # sigmoid -> scale in (0, 1)

        # Threshold tau = average * scale, broadcast to (b, c, 1, 1).
        tau = torch.mul(average, scale).unsqueeze(2).unsqueeze(2)

        # Soft thresholding: sign(x) * max(|x| - tau, 0).
        n_sub = torch.clamp(x_abs - tau, min=0.0)
        return torch.mul(torch.sign(x_raw), n_sub)
# ---------------------------SGAMAttention End---------------------------
配置文件写法:
# 1024换成上一层的通道数
[-1, 1, SGAMAttention, [1024, True, 4]],
在yolo.py中添加SGAMAttention
# Fragment for yolo.py parse_model(): registers SGAMAttention so the YAML
# parser computes its channel arguments like other channel-carrying modules.
elif m in [SGAMAttention]:
    # c1 = channels of the 'from' layer, c2 = requested output channels
    c1, c2 = ch[f], args[0]
    if c2 != no:  # not the detection output width
        # scale by width multiple and round to a multiple of 8
        c2 = make_divisible(c2 * gw, 8)
    args = [c1, c2, *args[1:]]
示例,将其加在backbone里的模型配置文件
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 4  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32
# YOLOv5 v6.0 backbone with three SGAMAttention layers inserted; layer-index
# comments below are renumbered to account for the insertions.
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],  # 2
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],  # 4
   [-1, 1, SGAMAttention, [256, True, 4]],  # 5
   [-1, 1, Conv, [512, 3, 2]],  # 6-P4/16
   [-1, 9, C3, [512]],  # 7
   [-1, 1, SGAMAttention, [512, True, 4]],  # 8
   [-1, 1, Conv, [1024, 3, 2]],  # 9-P5/32
   [-1, 3, C3, [1024]],  # 10
   [-1, 1, SPPF, [1024, 5]],  # 11
   [-1, 1, SGAMAttention, [1024, True, 4]],  # 12
  ]
# YOLOv5 v6.0 head (layer numbering continues from the backbone above)
head:
  [[-1, 1, Conv, [512, 1, 1]],  # 13
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],  # 14
   [[-1, 8], 1, Concat, [1]],  # 15 cat backbone P4 (layer 8 = SGAMAttention after P4 C3)
   [-1, 3, C3, [512, False]],  # 16
   [-1, 1, Conv, [256, 1, 1]],  # 17
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],  # 18
   [[-1, 4], 1, Concat, [1]],  # 19 cat backbone P3 — NOTE(review): layer 4 is the P3 C3, bypassing the SGAMAttention at layer 5; confirm this is intended
   [-1, 3, C3, [256, False]],  # 20 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],  # 21
   [[-1, 17], 1, Concat, [1]],  # 22 cat head P4
   [-1, 3, C3, [512, False]],  # 23 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],  # 24
   [[-1, 13], 1, Concat, [1]],  # 25 cat head P5
   [-1, 3, C3, [1024, False]],  # 26 (P5/32-large)
   [[20, 23, 26], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]