神经网络注意力机制图片处理及代码(一)

        本文主要用于记录笔者深度学习中注意力机制的学习过程,如有不准确之处,欢迎各路大神指出!谢谢!

SENet

原理:使网络关注它最需要关注的通道。

1.进行自适应平均池化

2.进行两次全连接

3.将样本值映射到0到1之间,获得了输入特征层每一个通道的权值

class se_block(nn.Module):
    def __init__(self, channel, ratio=16):
        super(se_block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
                nn.Linear(channel, channel // ratio, bias=False),
                nn.ReLU(inplace=True),
                nn.Linear(channel // ratio, channel, bias=False),
                nn.Sigmoid()
        )
  • nn.AdaptiveAvgPool2d()函数用于自适应平均池化,指定输出固定尺寸(H,W)。
  • nn.Sequential()视为多个模块封装成的单个模块,当用forward()方法接收输入之后,按照内部模块的顺序自动依次计算并输出结果。
  • nn.Linear()定义神经网络中的线性层,参数1表示输入的神经元个数,参数2表示输出神经元个数,参数3表示此线性层是否包含偏置。例:
from torch import nn
import torch

model = nn.Linear(2, 1) # 输入特征数为2,输出特征数为1
input = torch.Tensor([1, 2]) # 给一个样本,该样本有2个特征(这两个特征的值分别为1和2)
output = model(input)
output
tensor([-1.4167], grad_fn=<AddBackward0>)
  •  nn.ReLU()是一种激活函数,其内部的inplace参数若设为True,它会把输出直接覆盖到输入中,这样可以节省内存。
  • nn.Sigmoid()亦是一种激活函数,用于将样本值映射到0到1之间。

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y
  • .size()函数主要用于统计矩阵某一维上的元素个数。例:
    >>> a = np.array([[1,2,3],[4,5,6]])
    >>> np.size(a)
    6   #axis的值没有设定,返回矩阵的元素个数
    >>> np.size(a,1)
    3   #axis = 0,返回该二维矩阵的行数
    >>> np.size(a,0)
    2   #axis = 1,返回该二维矩阵的列数

  • self.avg_pool在第一段代码中已提到,用于自适应平均池化。

  • .view()函数用于重新定义矩阵形状,例:

    import torch
    v1 = torch.range(1, 16) 
    v2 = v1.view(4, 4)  
    

    其中v1为1*16大小的张量,包含16个元素,v2为4*4大小的张量,同样包含16个元素。

  • 注:矩阵形状改变前后元素个数要相同,不然会报错。

  • return x * y

    在获得输入特征层每一个通道的权值后,我们将这个权值乘上原输入特征层即可。

参考文章:

(12条消息) 神经网络学习小记录64——Pytorch 图像处理中注意力机制的解析与代码详解_Bubbliiiing的博客-CSDN博客

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值