基于Pytorch的MLP实现
目标
- 使用pytorch构建MLP网络
- 训练集使用MNIST数据集
- 使用GPU加速运算
- 要求准确率能达到92%以上
- 保存模型
实现
数据集:MNIST数据集的载入
MNIST数据集是一种常用的数据集,为28*28的手写数字训练集,label使用独热码,在pytorch中,可以使用torchvision.datasets.MNIST()
和torch.utils.data.DataLoader()
来导入数据集,其中
-
torchvision.datasets.MNIST()
:用于下载,导入数据集 -
torch.utils.data.DataLoader()
:用于将数据集整理成batch的形式并转换为可迭代对象
import torch as pt
import torchvision as ptv
import numpy as np
train_set = ptv.datasets.MNIST("../../pytorch_database/mnist/train",train=True,transform=ptv.transforms.ToTensor(),download=True)
test_set = ptv.datasets.MNIST("../../pytorch_database/mnist/test",train=False,transfor