论文下载:https://arxiv.org/abs/1903.06586
代码下载:https://github.com/implus/SKNet
视觉注意力机制——SKNet- Selective Kernel Networks
通常将软注意力机制:空间域、通道域、混合域、卷积域。
(1) 空间域——将图片中的的空间信息做相应的空间变换得到相应的权重分布,从而能将关键的信息提取出来。代表作有:Spatial Attention Module。
(2) 通道域——简单的说就是给每个通道上的信号都增加一个权重,来代表该通道与关键信息的相关性,通常权重越大,二者的相关性越高。代表作有:SELayer, Channel Attention Module。
(3) 混合域——通俗的讲就是在通道和空间上共同处理,先在空间上得到权重分布,在到通道上得到权重分布。代表作有:Spatial Attention Module+ Channel Attention Module。
(4) 卷积域——是在卷积核上做处理,得到权重分布,这是一种更高级的玩法,代表作有:SKUnit
一、SKNet
我们知道提高卷积核的感受野,可以获得更多的信息,但是往往获得的信息是没办法做区分的,如果能在卷积核中加一些权重,来帮助我们做信息区分,那么这就形成了卷积注意力机。SKNet基于卷积核的注意力机制,即卷积核的重要性,即不同的图像能够得到具有不同重要性的卷积核。据作者说,该模块在超分辨率任务上有很大提升,并且论文中的实验也证实了在分类任务上有很好的表现。SKNet对不同图像使用的卷积核权重不同,即一种针对不同尺度的图像动态生成卷积核。整体结构如下图所示:
此图为借鉴某公众号网络图
1.首先特征图X 经过3x3,5x5, 7x7, 等卷积得到U1,U2,U3三个特征图,然后相加得到了U,U中融合了多个感受野的信息。然后沿着H和W维度求平均值,最终得到了关于channel的信息是一个C×1×1的一维向量,结果表示各个通道的信息的重要程度。
2.接着再用了一个线性变换,将原来的C维映射成Z维的信息,然后分别使用了三个线性变换,从Z维变为原来的C,这样完成了正对channel维度的信息提取,然后使用Softmax进行归一化,这时候每个channel对应一个分数,代表其channel的重要程度,这相当于一个mask。
3.将这三个分别得到的mask分别乘以对应的U1,U2,U3,得到A1,A2,A3。然后三个模块相加,进行信息融合,得到最终模块A, 模块A相比于最初的X经过了信息的提炼,融合了多个感受野的信息。
import torch
from torch import nn
class SKConv(nn.Module):
"""
1.首先特征图X 经过3x3,5x5, 7x7, 等卷积得到U1,U2,U3三个特征图,然后相加得到了U,U中融合了多个感受野的信息。
然后沿着H和W维度求平均值,最终得到了关于channel的信息是一个C×1×1的一维向量,结果表示各个通道的信息的重要程度。
2.接着再用了一个线性变换,将原来的C维映射成Z维的信息,然后分别使用了三个线性变换,从Z维变为原来的C,这样完成了正对channel维度的信息提取。
然后使用Softmax进行归一化,这时候每个channel对应一个分数,代表其channel的重要程度,这相当于一个mask。
3.将这三个分别得到的mask分别乘以对应的U1,U2,U3,得到A1,A2,A3。
然后三个模块相加,进行信息融合,得到最终模块A, 模块A相比于最初的X经过了信息的提炼,融合了多个感受野的信息。
"""
def __init__(self, features, WH, M, G, r, stride=1, L=32):
super(SKConv, self).__init__()
d = max(int(features / r), L) # 取两个中最大的个值
self.M = M # 有多少路径
self.features = features
self.convs = nn.ModuleList([])
for i in range(M):
self.convs.append(nn.Sequential(
nn.Conv2d(features, features, kernel_size=3 + i * 2, stride=stride, padding=1 + i, groups=G),
nn.BatchNorm2d(features),
nn.ReLU(inplace=False)
))
# self.gap = nn.AvgPool2d(int(WH/stride))
self.fc = nn.Linear(features, d)
self.fcs = nn.ModuleList([])
for i in range(M):
self.fcs.append(
nn.Linear(d, features)
)
self.softmax = nn.Softmax(dim=1)
def forward(self, x):
for i, conv in enumerate(self.convs):
# (0): Conv2d、(1): Conv2d、(2): Conv2d....(n-1)
# (b, 1, h, w) -->(b, 1, 1, h, w)
fea = conv(x).unsqueeze_(dim=1)
if i == 0:
# (b, 1, 1, h, w)
feas = fea
else:
# (b, 2, 1, h, w)、(b, 3, 1, h, w)
feas = torch.cat([feas, fea], dim=1)
fea_U = torch.sum(feas, dim=1)
# fea_s = self.gap(fea_U).squeeze_()
fea_s = fea_U.mean(-1).mean(-1)
fea_z = self.fc(fea_s)
for i, fc in enumerate(self.fcs):
vector = fc(fea_z).unsqueeze_(dim=1)
if i == 0:
attention_vectors = vector
else:
attention_vectors = torch.cat([attention_vectors, vector], dim=1)
attention_vectors = self.softmax(attention_vectors)
attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
fea_v = (feas * attention_vectors).sum(dim=1)
return fea_v
if __name__=="__main__":
img = torch.randn(2, 64, 512, 512)
model = SKConv(64,0,3,1,1)
out = model(img)
criterion = nn.L1Loss()
loss = criterion(out, img)
loss.backward()
print("out shape:{}".format(out.shape))
print('loss value:{}'.format(loss))