```python
import torch
import torch.nn as nn


class NCM(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(NCM, self).__init__()
        # 1x1 projection in_ch -> out_ch, shared by the query/key/value
        # branches and by the channel-attention gate
        self.con1x1 = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True))
        # 1x1 projection out_ch -> in_ch, so the result can be added back to x
        self.con1_1 = nn.Sequential(
            nn.Conv2d(out_ch, in_ch, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(in_ch),
            nn.ReLU(inplace=True))
        # SE-style channel attention: global average pooling followed by a
        # bottleneck MLP (reduction ratio 16) and a sigmoid gate
        self.SE = nn.AdaptiveAvgPool2d(1)
        mid_ch = in_ch // 16
        self.shared_MLP = nn.Sequential(
            nn.Linear(in_ch, mid_ch),
            nn.ReLU(),
            nn.Linear(mid_ch, in_ch)
        )
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        # Query, key and value all go through the same 1x1 conv, so the three
        # projections share weights (and here produce identical tensors)
        x1 = self.con1x1(x)                          # value: (b, out_ch, h, w)
        b, c, h, w = x1.shape
        x2 = self.con1x1(x)                          # query
        x3 = self.con1x1(x)                          # key

        # Channel attention: pool -> MLP -> sigmoid, then project the
        # (b, in_ch, 1, 1) gate to out_ch to match the self-attention branch
        x4 = self.shared_MLP(self.SE(x).view(b, -1)).unsqueeze(2).unsqueeze(3)
        x4 = self.sigmoid(x4)
        x4 = self.con1x1(x4)                         # (b, out_ch, 1, 1)

        # Spatial self-attention over the n = h*w positions, computed per
        # sample with bmm. The original view(c, -1) flattened the batch into
        # the spatial axis and mixed attention across different samples.
        n = h * w
        attn = torch.bmm(x2.view(b, c, n).permute(0, 2, 1),  # (b, n, c)
                         x3.view(b, c, n))                   # (b, c, n) -> (b, n, n)
        attn = torch.softmax(attn, dim=-1)
        resh = torch.bmm(x1.view(b, c, n), attn).view(b, c, h, w)

        # Gate with the channel weights, project back to in_ch, add residual
        final = self.con1_1(resh * x4)
        return x + final

if __name__ == '__main__':
    model = NCM(in_ch=512, out_ch=256)
    x = torch.randn(16, 512, 24, 24)
    y = model(x)
    print(y.shape)  # torch.Size([16, 512, 24, 24]) -- residual output keeps the input shape
```
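Because the residual output has the same shape as the input, the block can be dropped between stages of an existing network. Below is a minimal sketch of that usage; the toy backbone, its layer sizes, and the input resolution are illustrative assumptions, not part of the original post:

```python
import torch
import torch.nn as nn

# Hypothetical two-stage backbone, used only to show where NCM could sit;
# every layer and channel count here is an assumption for this sketch.
backbone = nn.Sequential(
    nn.Conv2d(3, 512, kernel_size=3, stride=2, padding=1),
    nn.ReLU(inplace=True),
    NCM(in_ch=512, out_ch=256),   # attention block; preserves (N, 512, H, W)
    nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
)

feat = backbone(torch.randn(2, 3, 96, 96))
print(feat.shape)  # torch.Size([2, 1024, 24, 24])
```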
Structure diagram of the combination of channel attention and self-attention:
![](https://i-blog.csdnimg.cn/blog_migrate/495f8c27e1c61a67077c6b1961b144f9.png)
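As the diagram and the code above suggest, the block runs two branches in parallel: a spatial self-attention branch, which builds an (HW × HW) similarity matrix from the 1×1-projected features and softmax-normalizes it, and an SE-style channel attention branch, which squeezes the input with global average pooling, passes it through a bottleneck MLP (reduction ratio 16), and produces per-channel sigmoid weights. The channel weights rescale the self-attention output channel-wise before the final 1×1 projection back to `in_ch` and the residual addition to the input.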