视觉注意力机制——SKNet- Selective Kernel Networks

论文下载:https://arxiv.org/abs/1903.06586
代码下载:https://github.com/implus/SKNet

视觉注意力机制——SKNet- Selective Kernel Networks

通常将软注意力机制:空间域、通道域、混合域、卷积域。
(1) 空间域——将图片中的的空间信息做相应的空间变换得到相应的权重分布,从而能将关键的信息提取出来。代表作有:Spatial Attention Module。
(2) 通道域——简单的说就是给每个通道上的信号都增加一个权重,来代表该通道与关键信息的相关性,通常权重越大,二者的相关性越高。代表作有:SELayer, Channel Attention Module。
(3) 混合域——通俗的讲就是在通道和空间上共同处理,先在空间上得到权重分布,在到通道上得到权重分布。代表作有:Spatial Attention Module+ Channel Attention Module。
(4) 卷积域——是在卷积核上做处理,得到权重分布,这是一种更高级的玩法,代表作有:SKUnit

在这里插入图片描述

一、SKNet

我们知道提高卷积核的感受野,可以获得更多的信息,但是往往获得的信息是没办法做区分的,如果能在卷积核中加一些权重,来帮助我们做信息区分,那么这就形成了卷积注意力机。SKNet基于卷积核的注意力机制,即卷积核的重要性,即不同的图像能够得到具有不同重要性的卷积核。据作者说,该模块在超分辨率任务上有很大提升,并且论文中的实验也证实了在分类任务上有很好的表现。SKNet对不同图像使用的卷积核权重不同,即一种针对不同尺度的图像动态生成卷积核。整体结构如下图所示:
此图为GiantPandaCV公众号作者根据代码重画的网络图此图为借鉴某公众号网络图
1.首先特征图X 经过3x3,5x5, 7x7, 等卷积得到U1,U2,U3三个特征图,然后相加得到了U,U中融合了多个感受野的信息。然后沿着H和W维度求平均值,最终得到了关于channel的信息是一个C×1×1的一维向量,结果表示各个通道的信息的重要程度。
2.接着再用了一个线性变换,将原来的C维映射成Z维的信息,然后分别使用了三个线性变换,从Z维变为原来的C,这样完成了正对channel维度的信息提取,然后使用Softmax进行归一化,这时候每个channel对应一个分数,代表其channel的重要程度,这相当于一个mask。
3.将这三个分别得到的mask分别乘以对应的U1,U2,U3,得到A1,A2,A3。然后三个模块相加,进行信息融合,得到最终模块A, 模块A相比于最初的X经过了信息的提炼,融合了多个感受野的信息。

在这里插入图片描述

import torch
from torch import nn


class SKConv(nn.Module):
    """
    1.首先特征图X 经过3x3,5x5, 7x7, 等卷积得到U1,U2,U3三个特征图,然后相加得到了U,U中融合了多个感受野的信息。
      然后沿着H和W维度求平均值,最终得到了关于channel的信息是一个C×1×1的一维向量,结果表示各个通道的信息的重要程度。
    2.接着再用了一个线性变换,将原来的C维映射成Z维的信息,然后分别使用了三个线性变换,从Z维变为原来的C,这样完成了正对channel维度的信息提取。
      然后使用Softmax进行归一化,这时候每个channel对应一个分数,代表其channel的重要程度,这相当于一个mask。
    3.将这三个分别得到的mask分别乘以对应的U1,U2,U3,得到A1,A2,A3。
      然后三个模块相加,进行信息融合,得到最终模块A, 模块A相比于最初的X经过了信息的提炼,融合了多个感受野的信息。
    """
    def __init__(self, features, WH, M, G, r, stride=1, L=32):
        super(SKConv, self).__init__()
        d = max(int(features / r), L) # 取两个中最大的个值
        self.M = M # 有多少路径
        self.features = features
        self.convs = nn.ModuleList([])
        for i in range(M):
            self.convs.append(nn.Sequential(
                nn.Conv2d(features, features, kernel_size=3 + i * 2, stride=stride, padding=1 + i, groups=G),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=False)
            ))
        # self.gap = nn.AvgPool2d(int(WH/stride))
        self.fc = nn.Linear(features, d)
        self.fcs = nn.ModuleList([])
        for i in range(M):
            self.fcs.append(
                nn.Linear(d, features)
            )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        for i, conv in enumerate(self.convs):
            # (0): Conv2d、(1): Conv2d、(2): Conv2d....(n-1)
            # (b, 1, h, w) -->(b, 1, 1, h, w)
            fea = conv(x).unsqueeze_(dim=1)
            if i == 0:
                # (b, 1, 1, h, w)
                feas = fea
            else:
                # (b, 2, 1, h, w)、(b, 3, 1, h, w)
                feas = torch.cat([feas, fea], dim=1)

        fea_U = torch.sum(feas, dim=1)

        # fea_s = self.gap(fea_U).squeeze_()
        fea_s = fea_U.mean(-1).mean(-1)
        fea_z = self.fc(fea_s)
        for i, fc in enumerate(self.fcs):
            vector = fc(fea_z).unsqueeze_(dim=1)
            if i == 0:
                attention_vectors = vector
            else:
                attention_vectors = torch.cat([attention_vectors, vector], dim=1)
        attention_vectors = self.softmax(attention_vectors)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
        fea_v = (feas * attention_vectors).sum(dim=1)
        return fea_v



if  __name__=="__main__":
    img = torch.randn(2, 64, 512, 512)
    model = SKConv(64,0,3,1,1)
    out = model(img)
    criterion = nn.L1Loss()
    loss = criterion(out, img)
    loss.backward()
    print("out shape:{}".format(out.shape))
    print('loss value:{}'.format(loss))

在这里插入图片描述

  • 9
    点赞
  • 57
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值