使用pytorch对MNIST数据集进行分类(多项逻辑回归)

import torch
import torchvision.datasets
import torchvision.transforms
import torch.utils.data.dataloader as loader
import matplotlib.pyplot as plt

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=torchvision.transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100)
fc = torch.nn.Linear(28 * 28, 10)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(fc.parameters())

epoch_number = 10
for i in range(epoch_number):
    for _, (images, labels) in enumerate(train_loader):
        x = images.reshape(-1, 28*28)
        optimizer.zero_grad()
        preds = fc(x)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()
    print(f'{loss}\n')

correct = 0
total = 0
for images, labels in test_loader:
    x = images.reshape(-1, 28*28)
    preds = fc(x)
    predicted = torch.argmax(preds, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

accuracy = correct / total
print(accuracy)

for images, labels in train_loader:
    break
fig = plt.figure()
plt.imshow(images[0, 0], cmap='gray')
plt.show()

这段代码看了一天,现在对其进行回顾和总结。

1. pytorch的输入数据结构

pytorch将输入分为dataset和dataloader两部分

训练模型好比用砖块盖房,在训练初始,所有的训练数据都要准备好,即所有的砖块都要从砖厂里拉出来。
拉砖的车满满一车来了,这一车的砖叫dataset,一车砖叫一个epoch。
训练一个epoch就是把车上所有的砖块儿都训练一遍。两个epoch就是把这一车上的所有砖块儿训练两遍。

但是盖房的时候直接从车上取不方便,需要一个小推车一趟趟的从车上取。这叫batch,一批。

我们在构建batch对象之前,需要对这一整车砖进行规划,哪些砖块第一批,哪些第二批。
规划完以后,原来的dataset对象就变成了dataloader。

可以把dataloader看成一个二维的数据结构
[[一批], [一批], [一批], [一批], [一批], [一批], [一批], [一批], [一批]..]
所以从dataloader里边取数据的时候就要用一个for循环

现在简要回顾一下各个对象的结构
train_dataset包含60000个图片,每个图片都有对应的标签。
train_loader将train_dataset分成了600份,每份100张图片,每张图片都带有标签。
进入第一个for循环是对60000张图片的第一次遍历
里面的for循环有两种形式,第一种是enumerate()函数,它会返回每一批的批号,即每一车砖的车牌号,第二个返回值是每一批所有的图片(100张)以及对应的标签。
本代码在train阶段使用的是这一种,在test阶段是另一种。直接for循环,用images和labels进行接收。
images的结构是(100, 1, 28, 28)
这里100代表一共有100张图片,1代表(我也不知道代表什么,大概是因为是黑白的图片吧,1这个位置是留给彩色图片的), 28*28是行列各28个像素点
所以images的结构是一个四维的tensor向量
在这里插入图片描述
查看某一张图片,可以用plt.imshow(images[0, 0], cmap='gray')

x = images.reshape(-1, 28*28)

将4维的images变成2维,行数对应样本所在位置,即第几张图片。每行七百多列是将每个图片在每个像素点的灰度值铺展到1维空间。所以就是一张图片占一行,100张图片占100行。

这是一批,训练完这一批,权重值会有一次调整。
直到把dataset中所有数据遍历完。这叫一个epoch。
一个epoch完事代表这个数据集过一遍了,精度不符合要求的话,可以将训练集再过一遍,这叫第二个epoch。

注意,fc(x)输出的可不是10个值哦。因为x是100行,每行七百多列。所有fc(x)输出的是100行,每行10列。
相当于对100个样本分别做了预测,每个预测是在10个待选输出上给一个概率
所以fc(x)是一个二维的数据结构。
在后面测试时,那个1就是挑出每一行中最大概率值所在的位置

preds = fc(x)

predicted = torch.argmax(preds, 1)

所以经过torch.argmax(preds, 1)以后predicted就变成了一个一维的数组,每一列代表一张图片的预测值,是一个整数。因为数组的序号都是整数,所以可以和后面的labels进行直接比对。
argmax函数可以看这一篇
enumerate函数可以看这一篇
dataset、dataloader可以看这一篇

2.其他一些点

首先,我想查看某个图片,书上的方法是plt.imshow(images[0, 0], cmap='gray')
但是出了bug
Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.
下面还有一串,在网上看了很多解决方法
两种方法都可以用,但是第二种太麻烦了,所以我直接把那个文件删了一个,一共有3个,我把最近安装的一个删掉了,跑代码没再出现过问题。
在这里插入图片描述
完美解决。
此外imshow()函数在使用的时候要加plt.show(),书上并没有。

plt.imshow(images[0, 0], cmap='gray')
plt.show()

fig = plt.figure()
plt.imshow(images[0, 0], cmap='gray')
plt.show()

这两种方法都可以,但是还是推荐第二种,因为绘图就是需要在figure对象上进行操作。

此外(predicted == labels).sum().item()的用法可以看这一篇

CrossEntropyLoss()的概念是多项逻辑回归的互熵,借鉴了极大似然函数的思想。参考肖智清老师的
《神经网络与PyTorch实战》
本文大部分代码来自这本书的第6章。

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值