上一篇博客学习了如何搭建Inception网络,这篇博客主要讲述如何利用pytorch搭建ResNets网络。
上一篇博客中遗留了一个问题,就是1*1卷积核的作用,第一个作用是减少参数,第二个作用是压缩通道数,减少计算量。
理论上,随着网络深度的加深,训练应该越来越好,但是,如果没有残差网络,深度越深意味着用优化算法越难计算,ResNets网络模型优点在于它能够训练深层次的网络模型,并且有助于解决梯度消失和梯度爆炸的问题,而且能保证良好的性能。
1、ResNets结构图
从上图中可以看出,Resnets网络在计算时,在执行最后一个步骤的激活时,加上了原先的x的值,这样的操作就是防止梯度消失。
2、导入相关库、构造数据
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
#数据增强
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,),(0.3081,))
])
#构造数据集
train_dataset = datasets.MNIST(
root='../dataset/mnist',
download=False,
train=True,
transform=transform
)
test_dataset = datasets.MNIST(
root='../dataset/mnist',
download=False,
train=False,
transform = transform
)
train_loader = DataLoader(
dataset=train_dataset,
batch_size=64,
shuffle=True
)
test_loader = DataLoader(
dataset=test_dataset,
batch_size=64,
shuffle=True
)
这些代码都是在这一系列实验中共有的部分,不在做过多的解释。