三分钟学会使用系列(YOLOv5)|CloFormer: 注意力机制与卷积的完美融合,涨点神器!
- 原文地址:《Rethinking Local Perception in Lightweight Vision Transformer》
- 代码地址:《https://github.com/qhfan/CloFormer/tree/main/》
1. CloFormer 导读
CloFormer 引入了一种结合了注意力机制和卷积运算的模块AttnConv,能够捕捉高频的局部信息。相比于传统的卷积操作,AttnConv 使用共享权重和上下文感知权重,能够更好地处理图像中不同位置之间的关系。实验结果表明,CloFormer 在图像分类、目标检测和语义分割任务中具有优越的性能。
2. 代码
以下是 CloFormer 中 AttnMap 与 EfficientAttention 两个模块的 PyTorch 实现。
class AttnMap(nn.Module):
    """Channel-preserving gate: 1x1 conv -> MemoryEfficientSwish -> 1x1 conv.

    Args:
        dim: number of input channels; the output has the same channel count.
    """

    def __init__(self, dim):
        super().__init__()
        # The attribute name and the layer order are kept as-is so parameter
        # names (act_block.0.*, act_block.2.*) stay stable.
        layers = [
            nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0),
            MemoryEfficientSwish(),
            nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0),
        ]
        self.act_block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the conv / activation / conv stack."""
        return self.act_block(x)
class EfficientAttention(nn.Module):
    """Multi-branch attention over a ``(b, c, h, w)`` feature map.

    The ``num_heads`` heads are partitioned by ``group_split``. Every entry
    except the last is a local branch: a 1x1 conv produces q/k/v, a depthwise
    conv of the matching ``kernel_sizes`` entry mixes them, and the output is
    ``tanh(attn_block(q * k) * scale) * v``. The last entry is a global branch:
    softmax attention whose keys/values come from an average-pooled copy of
    the input. Branch outputs are concatenated on the channel axis and passed
    through a 1x1 projection.

    Args:
        dim: number of input (and output) channels.
        num_heads: total number of heads; must equal ``sum(group_split)``.
        group_split: heads per branch; the last entry is the global branch.
            An entry of 0 disables that branch.
        kernel_sizes: depthwise kernel size for each local branch
            (``len(kernel_sizes) + 1 == len(group_split)``).
        window_size: kernel and stride of the average pooling applied before
            the global key/value projection; 1 disables pooling.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
        qkv_bias: whether the q/k/v and output 1x1 convolutions use a bias.
    """

    def __init__(self, dim, num_heads, group_split: List[int], kernel_sizes: List[int], window_size=7,
                 attn_drop=0., proj_drop=0., qkv_bias=True):
        super().__init__()
        assert sum(group_split) == num_heads, "group_split must sum to num_heads"
        assert len(kernel_sizes) + 1 == len(group_split), \
            "group_split needs one entry per kernel size plus one for the global branch"
        # Otherwise the concatenated branches have num_heads * (dim // num_heads)
        # channels, which would not match the `dim` channels self.proj expects.
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.dim_head = dim // num_heads
        self.scalor = self.dim_head ** -0.5
        self.kernel_sizes = kernel_sizes
        self.window_size = window_size
        self.group_split = group_split
        convs = []
        act_blocks = []
        qkvs = []
        # zip stops at len(kernel_sizes), so the trailing global entry of
        # group_split is not visited here.
        for kernel_size, group_head in zip(kernel_sizes, group_split):
            if group_head == 0:
                # Disabled branch: nothing is appended, so the module lists
                # only hold the enabled branches, in order.
                continue
            channels = 3 * self.dim_head * group_head
            # groups == channels makes this a depthwise convolution.
            convs.append(nn.Conv2d(channels, channels, kernel_size,
                                   1, kernel_size // 2, groups=channels))
            act_blocks.append(AttnMap(self.dim_head * group_head))
            qkvs.append(nn.Conv2d(dim, channels, 1, 1, 0, bias=qkv_bias))
        if group_split[-1] != 0:
            global_channels = group_split[-1] * self.dim_head
            self.global_q = nn.Conv2d(dim, global_channels, 1, 1, 0, bias=qkv_bias)
            self.global_kv = nn.Conv2d(dim, global_channels * 2, 1, 1, 0, bias=qkv_bias)
            self.avgpool = nn.AvgPool2d(window_size, window_size) if window_size != 1 else nn.Identity()

        self.convs = nn.ModuleList(convs)
        self.act_blocks = nn.ModuleList(act_blocks)
        self.qkvs = nn.ModuleList(qkvs)
        self.proj = nn.Conv2d(dim, dim, 1, 1, 0, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    # NOTE: the misspelled method name is kept so existing callers keep working.
    def high_fre_attntion(self, x: torch.Tensor, to_qkv: nn.Module, mixer: nn.Module, attn_block: nn.Module):
        '''
        Local branch.

        x: (b c h w)
        to_qkv: 1x1 conv producing stacked q/k/v, (b (3 m d) h w)
        mixer: depthwise conv applied to the stacked q/k/v
        attn_block: module applied to the element-wise product q * k
        returns: (b (m d) h w)
        '''
        b, _, h, w = x.size()
        qkv = to_qkv(x)  # (b (3 m d) h w)
        qkv = mixer(qkv).reshape(b, 3, -1, h, w).transpose(0, 1).contiguous()  # (3 b (m d) h w)
        q, k, v = qkv  # (b (m d) h w)
        attn = attn_block(q.mul(k)).mul(self.scalor)
        attn = self.attn_drop(torch.tanh(attn))
        return attn.mul(v)  # (b (m d) h w)

    def low_fre_attention(self, x: torch.Tensor, to_q: nn.Module, to_kv: nn.Module, avgpool: nn.Module):
        '''
        Global branch.

        x: (b c h w)
        to_q: 1x1 conv producing queries at full resolution
        to_kv: 1x1 conv producing stacked keys/values from the pooled input
        avgpool: pooling module applied to x before to_kv
        returns: (b (m d) h w)
        '''
        b, _, h, w = x.size()
        q = to_q(x).reshape(b, -1, self.dim_head, h * w).transpose(-1, -2).contiguous()  # (b m (h w) d)
        kv = to_kv(avgpool(x))  # (b (2 m d) H W)
        # Take the token count from the pooled tensor itself: AvgPool2d floors
        # each spatial dim separately, so (h * w) // window_size**2 is wrong
        # whenever h or w is not a multiple of window_size.
        pooled_tokens = kv.shape[-2] * kv.shape[-1]
        kv = kv.reshape(b, 2, -1, self.dim_head, pooled_tokens).permute(1, 0, 2, 4, 3).contiguous()  # (2 b m (H W) d)
        k, v = kv  # (b m (H W) d)
        attn = self.scalor * q @ k.transpose(-1, -2)  # (b m (h w) (H W))
        attn = self.attn_drop(attn.softmax(dim=-1))
        res = attn @ v  # (b m (h w) d)
        return res.transpose(2, 3).reshape(b, -1, h, w).contiguous()

    def forward(self, x: torch.Tensor):
        '''
        x: (b c h w)
        returns: (b c h w)
        '''
        # Iterate the module lists directly instead of indexing them with the
        # kernel_sizes index: branches with 0 heads were never appended in
        # __init__, so positional indexing misaligns after a skipped branch.
        res = [
            self.high_fre_attntion(x, to_qkv, mixer, attn_block)
            for to_qkv, mixer, attn_block in zip(self.qkvs, self.convs, self.act_blocks)
        ]
        if self.group_split[-1] != 0:
            res.append(self.low_fre_attention(x, self.global_q, self.global_kv, self.avgpool))
        return self.proj_drop(self.proj(torch.cat(res, dim=1)))
3. 使用教程
以 YOLOv5 为例,将 EfficientAttention 模块插入骨干网络末端(SPPF 之前),模型配置文件如下。
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
# YOLOv5 v6.0 layout with an EfficientAttention layer inserted as backbone
# layer 9; every layer index after it is shifted by +1 relative to the stock
# config, so absolute "from" indices in the head are adjusted accordingly.

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   # NOTE(review): EfficientAttention.__init__ requires dim, num_heads,
   # group_split and kernel_sizes; with empty args the model parser must
   # supply all of them - TODO confirm the matching parse_model branch.
   [-1, 1, EfficientAttention, []],  # 9
   [-1, 1, SPPF, [1024, 5]],  # 10
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],  # 11
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 14

   [-1, 1, Conv, [256, 1, 1]],  # 15
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 18 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   # Stock YOLOv5 concatenates with the head's 1x1 Conv (layer 14 there);
   # after the +1 shift that Conv is layer 15, not 14.
   [[-1, 15], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 21 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 11], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 24 (P5/32-large)

   [[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
4. 参考
- 原文地址:《Rethinking Local Perception in Lightweight Vision Transformer》
- 代码地址:《https://github.com/qhfan/CloFormer/tree/main/》