The CoTAttention Module
Paper link: https://arxiv.org/pdf/2107.12292.pdf
Adding the CoTAttention module to MMYOLO

- Copy the open-source CoTAttention.py file into the mmyolo/models/plugins directory.
- Import the package MMYOLO uses for registering modules: from mmyolo.registry import MODELS
- Make sure the input-dimension argument of class CoTAttention is named in_channels (MMYOLO passes the input dimension in automatically, so the parameter name must stay consistent).
- Register "class CoTAttention(nn.Module)" with the @MODELS.register_module() decorator.
- Modify the mmyolo/models/plugins/__init__.py file.
- Run in the terminal:
  python setup.py install
- Modify the corresponding config file and set the plugins parameter "type" to "CoTAttention" (a minimal config sketch follows this list); for reference, see 【YOLO改进】主干插入注意力机制模块CBAM(基于MMYOLO)-CSDN博客.
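As a sketch of the last step: the plugin is attached through the plugins field of the backbone config, the same pattern the CBAM article uses. The base config file name and the choice of stages below are assumptions for illustration, not taken from this post; in_channels is not set here because MMYOLO fills it in from each stage's output channels.

# Hypothetical config sketch: inherit a YOLOv5-s base config and insert
# CoTAttention after the last two backbone stages (stage choice is an assumption).
_base_ = './yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py'

model = dict(
    backbone=dict(
        plugins=[
            dict(
                cfg=dict(type='CoTAttention', kernel_size=3),
                stages=(False, False, True, True))
        ]))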
The modified CoTAttention.py
import torch
from torch import nn
from torch.nn import functional as F

from mmyolo.registry import MODELS


@MODELS.register_module()
class CoTAttention(nn.Module):

    def __init__(self, in_channels=512, kernel_size=3):
        super().__init__()
        self.dim = in_channels
        self.kernel_size = kernel_size

        # Static context: grouped k*k convolution over the keys.
        self.key_embed = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
                      padding=kernel_size // 2, groups=4, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU()
        )
        # Value embedding: 1x1 convolution.
        self.value_embed = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 1, bias=False),
            nn.BatchNorm2d(in_channels)
        )

        factor = 4
        # Attention embedding: predicts a k*k attention map per channel
        # from the concatenation of the static context and the input.
        self.attention_embed = nn.Sequential(
            nn.Conv2d(2 * in_channels, 2 * in_channels // factor, 1, bias=False),
            nn.BatchNorm2d(2 * in_channels // factor),
            nn.ReLU(),
            nn.Conv2d(2 * in_channels // factor,
                      kernel_size * kernel_size * in_channels, 1)
        )

    def forward(self, x):
        bs, c, h, w = x.shape
        k1 = self.key_embed(x)                   # bs, c, h, w (static context)
        v = self.value_embed(x).view(bs, c, -1)  # bs, c, h*w

        y = torch.cat([k1, x], dim=1)            # bs, 2c, h, w
        att = self.attention_embed(y)            # bs, c*k*k, h, w
        att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)
        att = att.mean(2, keepdim=False).view(bs, c, -1)  # bs, c, h*w
        k2 = F.softmax(att, dim=-1) * v          # dynamic context
        k2 = k2.view(bs, c, h, w)

        return k1 + k2                           # fuse static and dynamic contexts


if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    cot = CoTAttention(in_channels=512, kernel_size=3)
    output = cot(input)
    print(output.shape)
The modified __init__.py
# Copyright (c) OpenMMLab. All rights reserved.
from .cbam import CBAM
from .Biformer import BiLevelRoutingAttention
from .CoTAttention import CoTAttention

__all__ = ['CBAM', 'BiLevelRoutingAttention', 'CoTAttention']
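After reinstalling the package, a quick sanity check (my own addition, not from the original post) is to build the plugin from the registry by its config type, the same way MMYOLO does when it parses the config:

import torch
from mmyolo.registry import MODELS

# Build CoTAttention from a config dict; this only succeeds if registration worked.
cfg = dict(type='CoTAttention', in_channels=256, kernel_size=3)
layer = MODELS.build(cfg)
out = layer(torch.randn(2, 256, 32, 32))
print(out.shape)  # expected: torch.Size([2, 256, 32, 32])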