文章目录
版权声明:本文为博主原创文章,转载请注明原文出处!
写作时间:2019-03-02 21:36:12
使用循环神经网络做手写数字识别
思路分析
做图像识别的使用卷积神经网络CNN是最好的选择,但是其实我们也可以使用循环神经网络RNN做,只是大部分时候没有卷积网络效果好!下面分析一下如何使用RNN做手写数字的识别。
- 数据的下载我们可以直接使用PyTorch中的
torchvision.datasets
提供的数据接口 - 对于每一张图像(28$\times$28)我们可以将图像的每一行看做一个样本,然后所有行排列起来做成一个有序序列。对于这个序列,我们就可以使用RNN做识别训练了。
- 下面的实现中使用一个LSTM+Linear层组合实现(不要使用经典RNN,效果不好),损失函数使用CrossEntropyLoss。
- 在实践中设置
batch_first=True
可以减少一些额外的维度变换和尺寸转换的代码,推荐使用
PyTorch实现
import torch
from torch import nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
torch.manual_seed(2019)
# 超参设置
EPOCH = 1 # 训练EPOCH次,这里为了测试方便只跑一次
BATCH_SIZE = 32
TIME_STEP