使用循环神经网络做手写数字识别

本文介绍了使用循环神经网络RNN,特别是LSTM,来识别手写数字的方法。通过将28x28像素的手写数字图像的每一行视为一个样本,构建有序序列,并利用PyTorch实现模型训练,最终达到较好的识别效果。
摘要由CSDN通过智能技术生成

版权声明:本文为博主原创文章,转载请注明原文出处!

写作时间:2019-03-02 21:36:12

使用循环神经网络做手写数字识别

思路分析

做图像识别的使用卷积神经网络CNN是最好的选择,但是其实我们也可以使用循环神经网络RNN做,只是大部分时候没有卷积网络效果好!下面分析一下如何使用RNN做手写数字的识别。

  1. 数据的下载我们可以直接使用PyTorch中的torchvision.datasets提供的数据接口
  2. 对于每一张图像(28$\times$28)我们可以将图像的每一行看做一个样本,然后所有行排列起来做成一个有序序列。对于这个序列,我们就可以使用RNN做识别训练了。
  3. 下面的实现中使用一个LSTM+Linear层组合实现(不要使用经典RNN,效果不好),损失函数使用CrossEntropyLoss。
  4. 在实践中设置batch_first=True可以减少一些额外的维度变换和尺寸转换的代码,推荐使用

PyTorch实现

import torch
from torch import nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

torch.manual_seed(2019)

# 超参设置
EPOCH = 1  # 训练EPOCH次,这里为了测试方便只跑一次
BATCH_SIZE = 32
TIME_STEP
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值