ResNet 的核心思想是引入一个所谓的「恒等快捷连接」(identity shortcut connection),直接跳过一个或多个层,如下图所示:
ImageNet的一个更深层次的残差函数F。
左图:一个积木块,BasicBlock,用于ResNet-34。右图:ResNet-50/101/152的bottleneck构建块。
BasicBlock
expansion是残差结构中输出维度是输入维度的多少倍,BasicBlock没有升维,所以expansion = 1
残差结构是在求和之后才经过ReLU层
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
bottleneck
注意Res18、Res34用的是BasicBlock,其余用的是Bottleneck。使用Bottleneck的目的为降低通道维的数量,提高速度。可简化为“降-卷-升”,一般expansion = 4,因为Bottleneck中每个残差结构输出维度都是输入维度的4倍。
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = conv1x1(inplanes, planes)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = conv3x3(planes, planes, stride)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = conv1x1(planes, planes * self.expansion)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out