SENet是由自动驾驶公司Momenta在2017年公布的一种全新的图像识别结构,它通过对特征通道间的相关性进行建模,把重要的特征进行强化来提升准确率。SENet 是2017 ILSVR竞赛的冠军。
论文:Squeeze-and-Excitation Networks
SE block的基本结构
- 给定一个输入 ,其特征通道数为C ,通过一系列卷积等一般变换后得到一个特征通道数为C的特征。
- Squeeze:顺着空间维度进行特征压缩,将每个二维的特征通道变成一个实数,这个实数某种程度上具有全局的感受野,并且输出的维度和输入的特征通道数相匹配。
- Excitation:基于特征通道间的相关性,每个特征通道生成一个权重,用来代表特征通道的重要程度。
- Reweight:将Excitation输出的权重看做每个特征通道的重要性,然后通过乘法逐通道加权到之前的特征上,完成在通道维度上的对原始特征的重标定。
代码:
import torch
import torch.nn as nn
import math
from torchvision import models
class se_block(nn.Module):
def __init__(self, channel, ratio=16):
super(se_block, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // ratio, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // ratio, channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
class Mobilenet_v2(nn.Module):
def __init__(self):
super(Mobilenet_v2, self).__init__()
model = models.mobilenet_v2(pretrained=True)
# Remove linear and pool layers (since we're not doing classification)
modules = list(model.children())[:-1]
self.resnet = nn.Sequential(*modules)
self.pool = nn.AvgPool2d(kernel_size=7)
self.fc = nn.Linear(1280, 16)
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
self.attention = se_block(1280) # 1280 为上层输出通道
def forward(self, images):
x = self.resnet(images) # [N, 1280, 1, 1]
x=self.attention(x) # 此处加入se—block
x = self.pool(x)
x = x.view(-1, 1280) # [N, 1280]
x = self.fc(x)
return x
if __name__=="__main__":
input = torch.rand(2, 3, 224, 224)
mode = Mobilenet_v2()
out = mode(input)
print(out.size())
小结:
1、SE网络可以通过堆叠SE模块得到。
2、SE模块也可以嵌入到现在几乎所有的网络结构中。