残差网络(ResNet)
残差网络(Residual Network, ResNet) 是由微软研究院的 Kaiming He 等人于 2015 年提出的深度卷积神经网络架构。ResNet 在 ImageNet 大规模图像分类挑战赛中取得了突破性的成绩,成为深度学习领域的一个里程碑。
ResNet 的核心思想是 残差学习(Residual Learning),即通过引入“跳跃连接(skip connection)”或者“残差连接(residual connection)”,使得网络能够更容易地训练更深的模型,并避免深度网络中常见的 梯度消失 和 梯度爆炸 问题。
1. ResNet 的核心思想
传统的深度神经网络在层数增加时,容易遇到 梯度消失(vanishing gradient) 或 梯度爆炸(exploding gradient) 问题,导致模型训练变得困难。为了应对这个问题,ResNet 引入了 残差块(Residual Block),它的关键概念是:
- 对于每一层,网络学习的是 残差,即期望的输出与输入之间的差值,而不是直接学习目标输出。
- 残差块通过 跳跃连接(skip connection),使得输入直接“跳过”某些层,连接到更深层的输出,从而保留了输入的信息,缓解了深度网络中的梯度问题。
2. 残差块(Residual Block)
每个残差块的结构可以用以下公式表示:
y = F ( x , { W i } ) + x y = F(x, \{W_i\}) + x y=F(x,{Wi})+x
其中:
- x x x 是残差块的输入。
- F ( x , { W i } ) F(x, \{W_i\}) F(x,{Wi}) 是一组卷积层、激活函数等操作后的输出,代表了网络学习的残差。
- x x x 被直接加到 F ( x , { W i } ) F(x, \{W_i\}) F(x,{Wi}) 的输出上,形成了最终的输出 y y y。
简言之,残差块通过加法将输入 x x x 与学习到的残差 F ( x , { W i } ) F(x, \{W_i\}) F(x,{Wi}) 相加,形成最终的输出。
残差块的基本结构
一个常见的残差块结构包括:
- 卷积层:通常使用 3 × 3 3 \times 3 3×3 卷积操作。
- Batch Normalization:对卷积后的输出进行标准化。
- ReLU 激活函数:对输出进行非线性变换。
- 跳跃连接(Skip Connection):输入直接加到卷积层的输出上。
如果输入和输出的维度不匹配,通常会使用 1x1 卷积 来匹配维度,确保加法操作可行。
3. ResNet 的基本架构
ResNet 由多个 残差块(Residual Blocks) 堆叠而成,具有非常深的网络结构。根据不同的深度,ResNet 可以有多种变种,比如 ResNet-18、ResNet-34、ResNet-50、ResNet-101、ResNet-152 等。
- ResNet-18 和 ResNet-34 通常用于较浅的网络,适合较简单的任务。
- ResNet-50、ResNet-101 和 ResNet-152 是深层版本,适用于更复杂的任务。
ResNet-50 架构示例
ResNet-50 是 ResNet 系列中常见的网络,其主要包含以下几部分:
- 输入层:通常为 224 × 224 × 3 224 \times 224 \times 3 224×224×3 的图像。
- 卷积层 + 最大池化:用来进行特征提取。
- 残差模块:多个残差块组成的模块。ResNet-50 由 4 个残差模块堆叠而成,每个模块包含不同数量的残差块。
- 全局平均池化:对每个通道进行池化,输出固定大小的特征图。
- 全连接层:输出最终的分类结果。
4. ResNet 的跳跃连接
ResNet 的核心在于跳跃连接的设计。跳跃连接允许网络的输入直接传递到更深层,而不是通过每一层逐渐传递。这种设计能够缓解以下问题:
- 梯度消失和梯度爆炸问题:通过跳跃连接,梯度可以直接传播到较浅的层,这样即使是深层的网络,也能保持梯度的有效传递。
- 信息流的阻塞问题:在没有跳跃连接的传统网络中,信息可能会在多层的计算中逐渐丢失,而跳跃连接使得输入信息可以直接流向深层。
示例:跳跃连接
假设某一层的输入为 x x x,经过卷积、激活等操作后输出为 F ( x ) F(x) F(x)。在添加跳跃连接后,最终输出为:
y = F ( x ) + x y = F(x) + x y=F(x)+x
这里, x x x 作为残差被直接加到输出中,解决了传统神经网络在训练过程中容易遇到的梯度消失和信息流丢失问题。
5. ResNet 的优点
-
解决梯度消失问题:
- 由于跳跃连接允许信息直接传递,ResNet 在训练深层网络时不容易出现梯度消失的问题,可以训练非常深的网络(如 152 层)。
-
有效利用深度网络:
- 传统的深度网络可能由于训练困难,无法有效利用深度带来的优势。而 ResNet 通过残差学习有效地利用了深度网络的优势,获得了更好的表现。
-
参数共享:
- 由于每个残差块只需要学习输入与输出之间的残差,ResNet 模型能有效地减少计算量和参数数量。
-
较好的泛化能力:
- 由于通过跳跃连接保持了信息流动,ResNet 在一些任务上展现出了较强的泛化能力。
6. ResNet 的缺点
-
训练时间较长:
- 虽然 ResNet 在训练时解决了梯度消失问题,但由于其网络深度较大,训练时间相对较长。
-
更复杂的模型设计:
- 尽管跳跃连接有效,但 ResNet 的设计和调试依然相对复杂,需要更高的计算资源。
-
计算资源要求较高:
- 更深的网络往往要求更多的计算资源和内存,训练时的成本较高。
7. PyTorch 实现 ResNet
以下是一个简单的 PyTorch 实现 ResNet-18 的代码示例:
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义残差块
class BasicBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
# 跳跃连接,确保输入和输出的维度相同
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += self.shortcut(x) # 加入跳跃连接
out = self.relu(out)
return out
# 定义 ResNet
class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=1000):
super(ResNet, self).__init__()
self.in_channels = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# 创建 ResNet 的各个残差块
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=
2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512, num_classes)
def _make_layer(self, block, out_channels, num_blocks, stride):
layers = []
layers.append(block(self.in_channels, out_channels, stride))
self.in_channels = out_channels
for _ in range(1, num_blocks):
layers.append(block(self.in_channels, out_channels))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
# ResNet-18
def ResNet18():
return ResNet(BasicBlock, [2, 2, 2, 2])
# 实例化模型
model = ResNet18()
print(model)
8. 总结
ResNet 是一种通过引入残差学习和跳跃连接的深度卷积神经网络架构,解决了深度神经网络中的梯度消失问题,使得非常深的神经网络成为可能。ResNet 的设计理念深刻影响了深度学习的发展,并成为许多计算机视觉任务中广泛使用的基础模型之一。