PyTorch实现LeNet网络
一. LeNet 网络介绍
我们将介绍LeNet,它是最早发布的卷积神经网络之一,因其在计算机视觉任务中的高效性能而受到广泛关注。 这个模型是由AT&T贝尔实验室的研究员Yann LeCun在1989年提出的(并以其命名),目的是识别图像 :cite:LeCun.Bottou.Bengio.ea.1998中的手写数字。 当时,Yann LeCun发表了第一篇在这里插入代码片
通过反向传播成功训练卷积神经网络的研究,这项工作代表了十多年来神经网络研究开发的成果。
LeNet是最经典的模型之一,主要分为两个部分:
- 卷积编码器:由两个卷积块组成
- 全连接层密集块:由三个全连接层组成
1. 网络构架
如下图,每个卷积块基本单元由一个卷积层、一个sigmoid激活函数和平均汇聚层组成。网络层次:
- 卷积层:输入输出通道 1,6;核 5*5;填充4;
- sigmoid 激活函数:
- 平均汇聚层:核 2*2;步幅 2;
- 卷积层:输入输出通道 6,16;核 5*5;
- sigmoid 激活函数:
- 平均汇聚层:核 2*2;步幅 2;
- Flatten 层:
- 全连接层:weight 400*120;
- sigmoid 激活函数:
- 全连接层:weight 120*84:
- sigmoid 激活函数:
- 全连接层:weight 84*10;
2. PyTorch模型构造
使用nn模块逐层定义每一层网络。
import torch
from torch import nn
# 网络构造
net = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5, padding=2),
nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=5),
nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Linear(16*5*5, 120),
nn.Sigmoid(),
nn.Linear(120, 84),
nn.Sigmoid(),
nn.Linear(84, 10)
)
3. 训练数据下载
# 数据下载
batch_size = 256
def load_FashionMNIST(batch_size):
# 将图片转换为tensor
trans = transforms.Compose([transforms.ToTensor()])
# 下载数据
fashion_minist_train = torchvision.datasets.FashionMNIST(root='../data',
train=True,
transform=trans,
download=False)
fashion_minist_test = torchvision.datasets.FashionMNIST(root='../data',
train=False,
transform=trans,
download=False)
# 将数据封装成下载器
train_iter = torch.utils.data.DataLoader(fashion_minist_train,
batch_size,
shuffle=True,
num_workers=4)
test_iter = torch.utils.data.DataLoader(fashion_minist_test,
batch_size,
shuffle=False,
num_workers=4)
load_FashionMNIST(batch_size)
4. 定义辅助函数
# 定义一个累加器
class Accumulator():
"""累加器"""
def __init__(self, num):
self.num = num
self.data = [0 for x in range(num)]
def add(self, *args):
if len(args) != self.num:
raise TypeError("parameter num error")
for index in range(self.num):
self.data[index] += args[index]
def __len__(self):
return self.num
def __getitem__(self, index):
return self.data[index]
# 定义预测函数,返回预测正确的数量
def accuracy(y_hat, y):
if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
y_hat = torch.max(y_hat, axis=1).indices
cmp = y_hat.type(y.dtype) == y
return float(cmp.type(y.dtype).sum())
# 计算模型在数据集上的精度
def evaluate_accuracy_gpu(net, data_iter, device=None):
"""计算精度"""
if isinstance(net, nn.Module):
net.eval()
if not device:
device = next(iter(net.parameters())).device
# 预测总数,预测正确的数量
metric = Accumulator(2)
with torch.no_grad():
for X, y in data_iter:
X = X.to(device)
y = y.to(device)
metric.add(accuracy(net(X), y), y.numel())
return metric[0] / metric[1]
5. 定义模型训练函数
# 模型训练函数
def train_net(net, train_iter, test_iter, num_epochs, lr, device):
"""训练网络模型"""
# 初始化模型参数
def init_weight(m):
if type(m) == nn.Linear or type(m) == nn.Conv2d:
nn.init.xavier_uniform_(m.weight)
net.apply(init_weight)
# 将模型复制到device上(cpu或者GPU)
net.to(device)
# 定义优化器
optimizer = torch.optim.SGD(net.parameters(), lr=lr)
# 定义损失函数
loss = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
# 创建累加器,元素分别表示:训练损失和、训练准确率之和、范例数
metric = Accumulator(3)
# 将网络设置成训练模式
net.train()
# 小批量梯度下降训练
for i, (X, y) in enumerate(train_iter):
# 梯度归零,减训练数据复制到device上
optimizer.zero_grad()
X, y = X.to(device), y.to(device)
y_hat = net(X)
# 计算梯度,更新网络参数
l = loss(y_hat, y)
l.backward()
optimizer.step()
# 使用累加器统计
metric.add(l*X.shape[0], accuracy(y_hat, y), X.shape[0])
#
train_l = metric[0]/metric[2]
train_acc = metric[1]/metric[2]
test_acc = evaluate_accuracy_gpu(net, test_iter)
print("train_epoch {}: test_acc {},train_acc {},train_l {}".format(epoch, test_acc, train_acc, train_l))
6. 模型训练
lr, num_epochs = 0.5, 20
train_net(net, train_iter, test_iter, num_epochs, lr, "cuda:0")
训练输出:
train_epoch 0: test_acc 0.1,train_acc 0.10226666666666667,train_l 2.317789077758789
train_epoch 1: test_acc 0.556,train_acc 0.31156666666666666,train_l 1.8352042436599731
train_epoch 2: test_acc 0.6079,train_acc 0.6118,train_l 0.9809877276420593
train_epoch 3: test_acc 0.6527,train_acc 0.6822333333333334,train_l 0.8190925717353821
train_epoch 4: test_acc 0.7138,train_acc 0.7161,train_l 0.7282840013504028
train_epoch 5: test_acc 0.7398,train_acc 0.7365166666666667,train_l 0.6744959354400635
train_epoch 6: test_acc 0.7362,train_acc 0.7523666666666666,train_l 0.6371067762374878
train_epoch 7: test_acc 0.7442,train_acc 0.7672166666666667,train_l 0.6048535108566284
train_epoch 8: test_acc 0.7838,train_acc 0.7797833333333334,train_l 0.5733680129051208
train_epoch 9: test_acc 0.7667,train_acc 0.7912,train_l 0.5487761497497559
train_epoch 10: test_acc 0.7644,train_acc 0.8009833333333334,train_l 0.5255491137504578
train_epoch 11: test_acc 0.7825,train_acc 0.8077166666666666,train_l 0.5072696208953857
train_epoch 12: test_acc 0.7565,train_acc 0.8156333333333333,train_l 0.4902157783508301
train_epoch 13: test_acc 0.8151,train_acc 0.8211333333333334,train_l 0.4788495600223541
train_epoch 14: test_acc 0.8,train_acc 0.82675,train_l 0.4645676016807556
train_epoch 15: test_acc 0.8143,train_acc 0.8309333333333333,train_l 0.4531358778476715
train_epoch 16: test_acc 0.7944,train_acc 0.8369166666666666,train_l 0.4419154226779938
train_epoch 17: test_acc 0.7803,train_acc 0.8392333333333334,train_l 0.43576905131340027
train_epoch 18: test_acc 0.8299,train_acc 0.8411166666666666,train_l 0.42694082856178284
train_epoch 19: test_acc 0.8398,train_acc 0.8459333333333333,train_l 0.41870349645614624