1. Residual Networks
This post implements handwritten digit recognition on the MNIST dataset with a CNN that uses residual blocks.
For background on residual networks, a well-written Zhihu article is worth reading: 详解残差网络 (A Detailed Explanation of Residual Networks).
Compared with simple networks such as LeNet, the difference in a residual network is the extra shortcut connection.
Left: the basic CNN structure; right: the structure with a residual connection.
The residual block is a classic and fundamental component of today's network architectures; DenseNet, for example, is a newer model built on top of residual blocks.
2. The MNIST Dataset
See my previous post: CNN-based handwritten digit recognition on the MNIST dataset.
3. Model Structure
Residual Block
Its structure is as follows:
apply two convolutions to x, add the result back to x, then apply the activation.
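Written as a formula (notation mine, matching the description above), the block computes:

```latex
y = \mathrm{ReLU}\bigl(x + \mathrm{conv}_2(\mathrm{ReLU}(\mathrm{conv}_1(x)))\bigr)
```

The identity path x is added before the final activation, which is what lets gradients flow directly through the shortcut.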
4. Code Implementation (PyTorch)
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # two parameters: mean and standard deviation
])
train_dataset = datasets.MNIST(
    root="../dataset/mnist/",
    train=True,
    download=True,
    transform=transform
)
train_loader = DataLoader(train_dataset,
                          shuffle=True,
                          batch_size=batch_size)
test_dataset = datasets.MNIST(
    root="../dataset/mnist/",
    train=False,
    download=True,
    transform=transform
)
test_loder = DataLoader(test_dataset,
                        shuffle=True,
                        batch_size=batch_size)
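Downloading MNIST is not needed to see how DataLoader batches data. A minimal sketch with a toy tensor dataset standing in for MNIST (the names `xs` and `ys` are illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for MNIST: 10 grayscale 28x28 images with random labels
xs = torch.randn(10, 1, 28, 28)
ys = torch.randint(0, 10, (10,))
loader = DataLoader(TensorDataset(xs, ys), batch_size=4, shuffle=True)

# 10 samples with batch_size=4 yield batches of size 4, 4, 2
batches = [images.shape[0] for images, labels in loader]
```

Shuffling changes which samples land in which batch, but not the batch sizes.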
class ResidualBlock(torch.nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.channels = channels
        self.conv1 = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        y = F.relu(self.conv1(x))
        y = self.conv2(y)
        return F.relu(x + y)
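A quick sanity check: because both convolutions use stride 1 and padding 1, the block preserves the input shape, so `x + y` is valid. A sketch (the block is repeated here so the snippet runs standalone; channel and batch sizes are illustrative):

```python
import torch
import torch.nn.functional as F

class ResidualBlock(torch.nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        y = F.relu(self.conv1(x))
        y = self.conv2(y)
        return F.relu(x + y)

block = ResidualBlock(16)
x = torch.randn(2, 16, 28, 28)  # batch of 2 feature maps with 16 channels
out = block(x)
# out.shape == x.shape: kernel_size=3 with padding=1 keeps the spatial size
```

Since the last operation is a ReLU, every element of the output is non-negative.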
'''
CLASS torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0,
                      dilation=1, groups=1, bias=True, padding_mode='zeros',