# -*- coding: utf-8 -*-
# @Time : 2018/12/24 11:04
# @Author : WenZhao
# @Email : 46546924@qq.com
# @File : mnistRnn-1.py
# @Software: PyCharm
'''
RNN识别mnist手写数字数据集
'''
import tensorflow as tf
import numpy as np
# 下载数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("./data/MNIST_data/",one_hot=True)
learning_rate=0.001
batch_size=128
n_input=28
n_steps=28
n_hidden=128
n_classes=10
x=tf.placeholder(tf.float32,[None,n_steps,n_input])
y=tf.placeholder(tf.float32,[None,n_classes])
output,_=tf.nn.dynamic_rnn(
tf.contrib.rnn.GRUCell(n_hidden),
x,
dtype=tf.float32,
sequence_length=batch_size*[n_input]
)
index=tf.range(0,batch_size)*n_steps+(n_input-1)
flat=tf.reshape(output,[-1,int(output.get_shape()[2])])
last=tf.gather(flat,index)
num_clas
RNN识别mnist手写数字数据集
最新推荐文章于 2022-01-11 15:25:13 发布
本文介绍如何利用循环神经网络(RNN)处理MNIST数据集,实现手写数字的识别。通过训练模型,展示RNN在图像识别领域的应用。
摘要由CSDN通过智能技术生成