1: Understanding the SE module: it consists of two parts, Squeeze and Excitation
Squeeze: global average pooling compresses each H*W feature map to a single value, turning a B*C*H*W input into a B*C channel descriptor.
Excitation: two fully connected layers (with a ReLU in between) followed by a Sigmoid map that descriptor to per-channel weights in (0, 1), which are then used to rescale the original feature map channel by channel.
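To make the two steps concrete, here is a minimal functional sketch (w1 and w2 are hypothetical random weights, assuming 64 channels and reduction r = 4); the SELayer module in the next section wraps exactly this computation in learnable layers:

import torch
import torch.nn.functional as F

x = torch.randn(2, 64, 32, 32)   # B*C*H*W feature map
w1 = torch.randn(16, 64)         # hypothetical FC weight, C -> C/r
w2 = torch.randn(64, 16)         # hypothetical FC weight, C/r -> C

z = x.mean(dim=(2, 3))                                     # Squeeze: B*C channel descriptor
s = torch.sigmoid(F.linear(F.relu(F.linear(z, w1)), w2))   # Excitation: B*C weights in (0, 1)
out = x * s.view(2, 64, 1, 1)                              # rescale each channel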
2: [PyTorch] code
import torch.nn as nn

class SELayer(nn.Module):
    def __init__(self, channel, reduction=4):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling: B*C*H*W -> B*C*1*1
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),  # channel should be divisible by reduction; integer division floors otherwise
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)  # B*C*1*1 flattened to B*C so it can be fed into the FC layers
        y = self.fc(y).view(b, c, 1, 1)  # the B*C output holds one weight per channel; reshape to B*C*1*1 to operate on the 4-D x
        return x * y.expand_as(x)  # expand B*C*1*1 to B*C*H*W (all H*W values within a channel are equal); * is elementwise multiplication
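A quick sanity check of the layer (the input sizes below are assumptions for illustration): the output has the same shape as the input, with each channel rescaled by its learned weight.

import torch

se = SELayer(channel=64, reduction=4)
x = torch.randn(2, 64, 32, 32)   # B*C*H*W
y = se(x)
print(y.shape)  # torch.Size([2, 64, 32, 32]) -- same shape, channels reweighted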
3: Applying SE in a Block
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()
        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # reduction is the factor by which the channel count is reduced inside the SE block
        # note: the first argument to SELayer must equal the out_channels of the previous conv layer
        self.se = SELayer(width, reduction=4)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel * self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)   # reweight the channels of the 3x3 conv output
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)
        return out
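A quick forward-pass check (the sizes below are assumptions for illustration): with stride=1 and in_channel equal to out_channel * expansion, no downsample is needed and the residual addition lines up.

import torch

block = Bottleneck(in_channel=256, out_channel=64)  # 64 * expansion == 256, so identity matches
x = torch.randn(2, 256, 56, 56)
out = block(x)
print(out.shape)  # torch.Size([2, 256, 56, 56])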