一、实现过程
残差网络(Residual Network)的特点是容易优化,并且能够通过增加相当的深度来提高准确率。其内部的残差块使用了跳跃连接,缓解了在深度神经网络中增加深度带来的梯度消失问题。
本文实现如图1所示的两层残差模块用于识别MNIST数据集,其中每一层均是卷积层。
残差构建模块封装成类,代码如下:
class ResidualBlock(torch.nn.Module):
def __init__(self,channels):
super(ResidualBlock,self).__init__()
self.channels = channels
self.conv1 = torch.nn.Conv2d(channels,channels,kernel_size=3,padding=1)
self.conv2 = torch.nn.Conv2d(channels,channels,kernel_size=3,padding=1)
def forward(self, x):
y = F.relu(self.conv1(x))
y = self.conv2(y)
return F.relu(x+y)
嵌入残差模块的网络模型代码如下:
class Net(torch.nn.Module):
def __init__(self):
super(Net,self).