TensorFlow: MNIST handwritten digit classification with an RNN

Using an RNN to classify handwritten digits

An RNN (Recurrent Neural Network) is a class of neural networks designed for sequential data, so the MNIST images must first be reshaped into sequences before they can be fed into the RNN.
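As a sketch of that preprocessing step (using a dummy NumPy batch in place of real MNIST data), each flattened 784-pixel image becomes a sequence of 28 rows, where each row is a 28-dimensional input vector:

```python
import numpy as np

# A flattened MNIST image is a 784-dimensional vector; treating each of the
# 28 pixel rows as one time step yields a sequence of 28 vectors of length 28.
batch = np.zeros((100, 784), dtype=np.float32)  # dummy batch in MNIST's flat format
sequences = batch.reshape(-1, 28, 28)           # (batch, time_step, input_size)

print(sequences.shape)  # (100, 28, 28)
```

This is exactly what the `tf.reshape(train_x, [-1, TIME_STEP, INPUT_SIZE])` line in the code below does inside the graph.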

I won't cover the theory of RNNs here; this post only shares the implementation. The code is aimed at beginners and is fairly simple. I am new to RNNs myself, and the code is adapted from an example in a book; if anything is unclear, see the comments or look it up.

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('mnist/', one_hot=True)

##### Parameters
BATCH_SIZE = 100     ## mini-batch size
TIME_STEP = 28       ## length of the input sequence fed to the LSTM: 28 image rows
INPUT_SIZE = 28      ## length of each input vector x_i: 28 image columns
LR = 0.001           ## learning rate
LSTM_UNITE = 100     ## number of LSTM units
N_CLASSE = 10        ## number of output classes
ITERATION = 8000     ## number of training iterations


## Placeholders for the input X and labels Y
train_x = tf.placeholder(tf.float32, [None, 784])
image = tf.reshape(train_x, [-1, TIME_STEP, INPUT_SIZE])
train_y = tf.placeholder(tf.int32, [None, 10])


##  Define the RNN (LSTM) network
rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units = LSTM_UNITE)
outputs, final_state = tf.nn.dynamic_rnn(
	cell = rnn_cell,         ##  the RNN cell to unroll
	inputs = image,          ##  input data, shape (batch, time_step, input_size)
	initial_state = None,    ##  initial state (defaults to zeros)
	dtype = tf.float32,      ##  data type
	time_major = False       ##  inputs are (batch, time, ...), not (time, batch, ...)
)

## Take the output at the last time step and project it to N_CLASSE logits
output = tf.layers.dense(inputs = outputs[:, -1, :], units = N_CLASSE)

##  Compute the loss (softmax cross-entropy against the one-hot labels)
loss = tf.losses.softmax_cross_entropy(onehot_labels = train_y, logits = output)

train_op = tf.train.AdamOptimizer(LR).minimize(loss)

correct_prediction = tf.equal(tf.argmax(train_y, axis=1), tf.argmax(output, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  # compute accuracy



sess = tf.Session()
sess.run(tf.global_variables_initializer())     # initialize the variables in the graph


for step in range(ITERATION):    # training loop
    x, y = mnist.train.next_batch(BATCH_SIZE)
    _, loss_ = sess.run([train_op, loss], {train_x: x, train_y: y})
    if step % 500 == 0:      # evaluate on a held-out test batch
        test_x, test_y = mnist.test.next_batch(5000)
        accuracy_ = sess.run(accuracy, {train_x: test_x, train_y: test_y})
        print('train loss: %.4f' % loss_, '| test accuracy: %.2f' % accuracy_)
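The `accuracy` node above counts how often the index of the largest logit matches the index of the one-hot label. The same computation in plain NumPy, on a tiny made-up batch of two samples, looks like this:

```python
import numpy as np

# Two samples, three classes: logits are raw scores, labels are one-hot.
logits = np.array([[2.0, 0.1, 0.3],
                   [0.2, 0.5, 3.0]])
labels = np.array([[1, 0, 0],
                   [0, 1, 0]])          # ground truth: class 0, class 1

# Prediction is correct when the argmax of logits matches the argmax of labels.
correct = np.argmax(logits, axis=1) == np.argmax(labels, axis=1)
accuracy = correct.astype(np.float32).mean()
print(accuracy)  # 0.5: the first sample is right, the second is wrong
```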

Output:
train loss: 2.3014 | test accuracy: 0.14
train loss: 0.3562 | test accuracy: 0.94
train loss: 0.0380 | test accuracy: 0.95
train loss: 0.1288 | test accuracy: 0.97
train loss: 0.0875 | test accuracy: 0.97
train loss: 0.0647 | test accuracy: 0.98
train loss: 0.1115 | test accuracy: 0.97
train loss: 0.0437 | test accuracy: 0.98
train loss: 0.0865 | test accuracy: 0.98
train loss: 0.0221 | test accuracy: 0.99
train loss: 0.0974 | test accuracy: 0.99
train loss: 0.0589 | test accuracy: 0.98
train loss: 0.0747 | test accuracy: 0.99
train loss: 0.0113 | test accuracy: 0.98
train loss: 0.0229 | test accuracy: 0.98
train loss: 0.0153 | test accuracy: 0.98
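For readers who want a feel for what the recurrent cell does at each of the 28 time steps, here is a minimal vanilla-RNN step in NumPy. This is a simplification, not the code above: `BasicLSTMCell` additionally has input, forget, and output gates, and the weights `W_x`, `W_h`, `b` here are hypothetical, randomly initialized placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
INPUT_SIZE, HIDDEN = 28, 100

# Hypothetical weights (a real LSTM also has gate weights)
W_x = rng.normal(scale=0.1, size=(INPUT_SIZE, HIDDEN))
W_h = rng.normal(scale=0.1, size=(HIDDEN, HIDDEN))
b = np.zeros(HIDDEN)

sequence = rng.normal(size=(28, INPUT_SIZE))  # one image: 28 rows as time steps
h = np.zeros(HIDDEN)                          # initial hidden state (zeros)
for x_t in sequence:                          # one step per image row
    h = np.tanh(x_t @ W_x + h @ W_h + b)      # state carries info across rows

print(h.shape)  # (100,)
```

The final state `h` plays the same role as `outputs[:, -1, :]` in the graph above: it summarizes the whole 28-row sequence and is what gets projected to the 10 class logits.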
