out += residual
return F.relu(out)
第二个残差模块
第二个残差模块（Bottleneck）用于实现 ResNet50、ResNet101、ResNet152 模型。SE 模块嵌入在第三个卷积（及其 BatchNorm）之后、与残差相加之前。
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1 convolutions) with an SE layer.

    The SE layer is applied to the output of the third convolution (after its
    BatchNorm) and before the residual addition; ReLU follows the addition.

    Args:
        in_places: number of input channels.
        places: number of channels of the two inner convolutions. The block
            outputs ``places * expansion`` channels.
        stride: stride of the 3x3 convolution (and of the projection shortcut
            when ``downsampling`` is True).
        downsampling: if True, the shortcut is a 1x1 conv + BatchNorm projection
            from ``in_places`` to ``places * expansion`` channels using
            ``stride``. If False, the input is added unchanged, so its shape must
            already match the main branch output (``in_places ==
            places * expansion`` and ``stride == 1``).
        expansion: channel multiplier applied by the last 1x1 convolution.
        reduction: reduction ratio passed to ``SELayer``. Defaults to 16, the
            value that was previously hard-coded.
    """

    def __init__(self, in_places, places, stride=1, downsampling=False, expansion=4, reduction=16):
        # Must be ``__init__`` (not ``init``) so nn.Module's constructor runs and
        # the submodules below are registered as parameters of this block.
        super(Bottleneck, self).__init__()
        self.expansion = expansion
        self.downsampling = downsampling
        out_channels = places * self.expansion
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels=in_places, out_channels=places, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(places),
            nn.ReLU(inplace=True),
            # Only the 3x3 convolution carries the stride (spatial downsampling).
            nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(places),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=places, out_channels=out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.se = SELayer(out_channels, reduction)
        if self.downsampling:
            # Projection shortcut so the residual matches the main branch in
            # both channel count and spatial size.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels=in_places, out_channels=out_channels, kernel_size=1, stride=stride,
                          bias=False),
                nn.BatchNorm2d(out_channels),
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(se(bottleneck(x)) + shortcut(x))``."""
        residual = x
        out = self.bottleneck(x)
        out = self.se(out)
        if self.downsampling:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
SEResNet18、SEResNet34模型的完整代码
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torchsummary import summary
class SELayer(nn.Module):
def init(self, channel, reduction=16):
super(SELayer, self).init()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.ex