PyTorch 训练 MNIST 数据集(含验证集)

欢迎关注

1. 本地创建文件夹保存数据集

from pathlib import Path
import requests

pathlib库在python3.4以后是python的内置库, Python 文档给它的定义是 Object-oriented filesystem paths(面向对象的文件系统路径),基本上可以代替os.path来处理路径。

# 指定路径,如果没有,就建一个文件夹
DATA_PATH = Path(r"D:\\data666")
PATH = DATA_PATH / "minist"

PATH.mkdir(parents=True, exist_ok=True)

pathlib的mkdir接收两个参数:

  • parents:如果父目录不存在,是否创建父目录。
  • exist_ok:只有在目录不存在时创建目录,目录已存在时不会抛出异常。
URL = "http://deeplearning.net/data/mnist/"  # 下载 mnist 数据集的地址
FILENAME = 'mnist.pkl.gz'

if not (PATH / FILENAME).exists():
    content = requests.get(URL + FILENAME).content
    (PATH / FILENAME).open("wb").write(content)

2. 数据集解压

import pickle
import gzip

with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), (x_test, y_test)) = pickle.load(f, encoding="latin-1")
print(x_train.shape)
print(y_train.shape)

print(x_valid.shape)
print(y_valid.shape)

print(x_test.shape)
print(y_test.shape)

print(x_train[0].shape)

运行结果:

(50000, 784)
(50000,)
(10000, 784)
(10000,)
(10000, 784)
(10000,)
(784,)

3. 网络训练

matplotlib.pyplot.imshow()显示图像的颜色问题,想要改变分类结果图的颜色,那么可以通过改变 Colormap 来实现。imshow()函数格式为:

matplotlib.pyplot.imshow(X, cmap=None)

  • X: 要绘制的图像或数组。

  • cmap: 颜色图谱(colormap), 默认绘制为RGB(A)颜色空间。例如:matplotlib.pyplot.imshow(img, cmap=jet)

其它可选的颜色图谱请参照下面的链接:http://www.cnblogs.com/denny402/p/5122594.html

主线程序:

from matplotlib import pyplot
import numpy as np
import torch
from torch.utils.data import DataLoader
%matplotlib inline
pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")  # 打印灰度图

运行结果:

在这里插入图片描述

主线程序:

pyplot.imshow(x_train[0].reshape((28, 28)))  # 默认瑞利图

运行结果:

在这里插入图片描述

主线程序:

print(type(x_train))
# 批量转换 tensor
x_train, x_test, x_valid, y_valid = map(
    torch.tensor, (x_train, x_test, x_valid, y_valid)
)
print(type(x_train))

运行结果:

<class 'numpy.ndarray'>
<class 'torch.Tensor'>

主线程序:

class Mnist_Logistic(torch.nn.Module):
    def __init__(self):
        super(Mnist_Logistic, self).__init__()
        self.lin = torch.nn.Linear(784, 10)

    def forward(self, xb):
        return self.lin(xb)
def get_model():
    model = Mnist_Logistic()
    return model, torch.optim.SGD(model.parameters(), lr = 0.1)
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(dataset=train_ds, batch_size=256, shuffle=True)
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(dataset=valid_ds, batch_size=512)
loss_func = torch.nn.CrossEntropyLoss()
model, opt = get_model()

for epoch in range(10):
    
    model.train()  # 训练前加
    for xb, yb in train_dl:
        pred = model(xb)
        loss = loss_func(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()
    
    # 加入验证集
    model.eval()  # 评估模型
    with torch.no_grad():
        valid_loss = sum(loss_func(model(xb), yb) for xb, yb in valid_dl)
    print("%d  %f " % (epoch, valid_loss / len(valid_dl)))

运行结果:

0  0.455661 
1  0.382214 
2  0.352814 
3  0.335510 
4  0.325002 
5  0.316920 
6  0.309989 
7  0.305217 
8  0.302514 
9  0.298360 
  • 1
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值