Pytorch框架实现LeNet-5
1.背景与目标
在这一讲中,我们将讲解近年来流行的深度学习编程工具Pytorch
的使用方法。最近几年Pytorch
工具使用份额日益增长,目前已经成为学术界研究深度学习的第一编程工具。这一讲我们仍然以LeNet
为例来讲解Pytorch
这一编程工具。
2. 基于LeNet的Pytorch实现
2.1 主函数main()
首先,我们打开main.py
程序:
import torch
from torch.utils.data import DataLoader
import scipy.misc
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torch import optim
'''LeNet in PyTorch.'''
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5, padding = 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
out = F.relu(self.conv1(x))
out = F.max_pool2d(out, 2)
out = F.relu(self.conv2(out))
out = F.max_pool2d(out, 2)
out = out