MNIST手写数字识别:分类任务中的最常见的、最简单的数据集mnist手写数字识别,使用的是CPU,没有用CUDA。
网络结构:一个输入层,一个两层的隐含层,一个输出层。
通过三层线性层的嵌套,然后再线性层的末尾添加一个非线性的激活函数Relu来增加网络的非线性的表达能力。
最后一层网络的激活函数一般不用Relu,可以根据具体任务决定,这里并没有用激活函数。
代码:
mnist_train.py
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot
batch_size = 512
# 加载数据
# 'mnist_data':加载mnist数据集,路径
# train=True:选择训练集还是测试
# download=True:如果当前文件没有mnist文件就会自动从网上去下载(pytorch官方文档)
# torchvision.transforms.ToTensor():下载好的数据一般是numpy格式,转换成Tensor
# torchvision.transforms.Normalisze((0.1307,), (0.3081,)):正则化过程,为了让数据更好的在0的附近均匀的分布(经验所得,可省略,但是网络性能会略下降)
# batch_size=batch_size:表示一次加载、处理多少张图片,>=1提高效率
# shuffle=True 加载的时候做一个随机的打散
# 加载训练集
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data', train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=True)
# 加载测试集
# shuffle=False 测试集不用打散
test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=False)
# 中间过程
# x, y = next(iter(train_loader))
# print(x.shape, y.shape, x.min(), x.max()) #迭代过程观察中间数据
# plot_image(x, y, 'image sample')
class Net(nn.Module):
# 初始化网络
def __init__(self):
super(Net, self).__init__()
# 新建三层,每层x*w+b
self.fc1 = nn.Linear(28 * 28, 256) # 线性层,28 * 28指输入层,256是经验所得,可以更改,大维到小维降维过程
self.fc2 = nn.Linear(256, 64) #64是经验得到,大维到小维降维过程
self.fc3 = nn.Linear(64, 10) # 10是分类结果种类,固定值
# 计算过程
def forward(self, x):
# x: [b, 1, 28, 28]
# h1 = relu(xw1+b1)
x = F.relu(self.fc1(x))
# h2 = relu(h1w2+b2)
x = F.relu(self.fc2(x))
# h3 = h2w3+b3,最后一层加不加激活函数取决于具体的任务,输出是输出概率值
x = self.fc3(x) # 分类问题一般是softmax + mean squre error均方差 ,简单起见直接使用softmax
return x
# 完成一个实例化
net = Net()
# [w1, b1, w2, b2, w3, b3]
# net.parameters(): 会帮我们拿到权值 [w1, b1, w2, b2, w3, b3]
# momentum: 动量,帮助更好的优化的一个策略
# lr: 学习率,经验调整
# SGD:梯度下降的优化器
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
train_loss = []
# 对整个数据集迭代三遍
for epoch in range(3):
# 对整个数据集迭代一次,即对一个batch迭代一次,一次batch 就512张图片
for batch_idx, (x, y) in enumerate(train_loader):
# x: [b, 1, 28, 28], y: [512]
# [b, 1, 28, 28] => [b, 784]
x = x.view(x.size(0), 28 * 28)
# => [b, 10]
out = net(x) # 经过了class Net(nn.Module)
# [b, 10]
y_onehot = one_hot(y)
# loss = mse(out, y_onehot) mse 是 均方差
loss = F.mse_loss(out, y_onehot) # 1:计算out与y_onehot之间的均方差,得到loss
optimizer.zero_grad() # 先对梯度进行清零
loss.backward() # 2:梯度计算过程,计算梯度
# w' = w - lr*grad
optimizer.step() # 3:更新权值
train_loss.append(loss.item())
if batch_idx % 10 == 0: # 每隔10个batch打印一下
print(epoch, batch_idx, loss.item()) # 第几个大循环(一共3个), 第多少批次eg:10 20 30 ..., loss显示
plot_curve(train_loss)
# we get optimal [w1, b1, w2, b2, w3, b3]
# 下面用测试集来进行测试
total_correct = 0
for x, y in test_loader:
x = x.view(x.size(0), 28 * 28)
out = net(x)
# out: [b, 10] => pred: [b]
pred = out.argmax(dim=1)
correct = pred.eq(y).sum().float().item()
total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)
x, y = next(iter(test_loader)) # 取一个batch,查看预测结果
out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1) # 取得[b, 10]的10个值的最大值所在位置的索引
plot_image(x, pred, 'test')
辅助文件:utils.py,绘制一些变量过程以及结果
import torch
from matplotlib import pyplot as plt
# 训练过程中loss function下降曲线的绘制
def plot_curve(data):
fig = plt.figure()
plt.plot(range(len(data)), data, color='blue')
plt.legend(['value'], loc='upper right')
plt.xlabel('step')
plt.ylabel('value')
plt.show()
# 绘制图片和图片识别结果以及图片真实的值
def plot_image(img, label, name):
fig = plt.figure()
for i in range(6):
plt.subplot(2, 3, i + 1)
plt.tight_layout()
plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
plt.title("{}: {}".format(name, label[i].item()))
plt.xticks([])
plt.yticks([])
plt.show()
# 通过scatter_函数实现one_hot编码
def one_hot(label, depth=10):
out = torch.zeros(label.size(0), depth)
idx = torch.LongTensor(label).view(-1, 1)
out.scatter_(dim=1, index=idx, value=1)
return out