AI练手系列(三)—— RNN实现mnist手写数字识别

mnist数据集作为一个经典的数据集,不少机器学习深度学习的初学者都对其下过手,它干净的数据集、训练难度也不是很高,作为初学练手,很适合了。今天我也来拜会一下这位经典大哥,看看到底感觉如何。

本次训练使用RNN的网络模型,其实其他模型也可以,但是这两天不是正好温习RNN嘛,就刚好学以致用嘛。

请套路先生上场吧:

1.导包定参

当你新建完一个python文件不管是py文件还是ipynb文件,不管会不会,先把常用的包导入进来再说,常定义几个超参数也请上来:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt 
from torch.utils.data import DataLoader

 #定义超参数
EPOCH = 10
BATCH_SIZE = 64
LR = 0.001
TIME_STEP = 28
INPUT_SIZE = 28

2.数据加载

先把数据加载进来,虽然mnist数据集可以在线下载,但可能比较慢,我更喜欢先把数据下载到本地再加载进来,这一块的处理和cifar-10那篇就很像了。可以去回顾一下:

AI练手系列(一)—— 利用Pytorch训练CIFAR-10数据集

#数据加载
train_data = datasets.MNIST (root='./mnist',train=True,transform=transforms.ToTensor(),download=True )
test_data = datasets.MNIST(root='./mnist',train=False,
                              transform=transforms.ToTensor(),download=True )
                            
print(train_data)
print(test_data)

我们把数据的一些特征打印出来看一眼:

Dataset MNIST

Number of datapoints: 60000

Root location: ./mnist

Split: Train

StandardTransform

Transform: ToTensor()

Dataset MNIST

Number of datapoints: 10000

Root location: ./mnist

Split: Test

StandardTransform

Transform: ToTensor()

可以看到训练集60000个样本,测试集10000个,现在已经都被我们转为tensor格式了。

下面该干嘛了,数据预处理哈,归一化、分批啥的都可以搞起了。

3.数据预处理

图片数据可视化

数据是图像的话我比较喜欢先搞一张出来看看:

plt.imshow(train_data.train_data[0].numpy(),cmap='gray'
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值