# TensorFlow基础教程：搭建循环神经网络RNN

• TensorFlow版本1.4.0
• Python版本>3.5.0

1.载入MNIST数据集

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST from the local "MNIST_data/" directory with one-hot encoded labels.
mnist = input_data.read_data_sets(train_dir="MNIST_data/", one_hot=True)


2.定义参数
RNN的输入是一个时间序列。MNIST数据集中图片大小为28px*28px，可以将图片的每一行（28个像素）看成一个时间步的输入；图片共有28行，因此时间步长为28，每个时间步的输入维度也是28。

# Hyperparameters shared by the graph definition and the training loop.
batch_size = 64    # samples per training step
n_input = 784      # flattened image size (28 * 28 pixels)
time_steps = 28    # sequence length: one time step per image row
input_size = 28    # features per time step: the pixels of one row
num_classes = 10   # number of target classes
rnn_size = 128     # hidden units in the RNN layer
lr = 0.01          # learning rate
# The training loop iterates over range(train_epochs), but the original text
# never defined it. NOTE(review): 5 is a chosen default, not a value from the
# article -- tune as needed.
train_epochs = 5


3.定义网络输入

# Graph inputs; the leading dimension is None so any batch size can be fed.
# x carries n_input float values per sample, y carries num_classes values.
x = tf.placeholder(dtype=tf.float32, shape=(None, n_input))
y = tf.placeholder(dtype=tf.float32, shape=(None, num_classes))


4.定义网络主结构

def rnn_model(x):
    """Build the LSTM classifier graph and return the class logits.

    Args:
        x: Float tensor holding flattened inputs; it is reshaped to
            [-1, time_steps, input_size] before entering the RNN.

    Returns:
        Tensor of shape [batch_size, num_classes] with unnormalised scores
        (no softmax applied).
    """
    # Reshape the flat input to [batch_size, time_steps, input_size].
    x = tf.reshape(x, shape=[-1, time_steps, input_size])
    # A single LSTM layer with rnn_size hidden units.
    rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
    # Run the RNN over the sequence; outputs has shape
    # [batch_size, time_steps, rnn_size]. The final state is not needed.
    outputs, _ = tf.nn.dynamic_rnn(rnn_cell, x, dtype=tf.float32)
    # Keep only the last time step: [batch_size, rnn_size]. Slicing directly
    # is equivalent to transposing to time-major and taking index -1.
    last_output = outputs[:, -1, :]
    # Fully connected layer mapping to [batch_size, num_classes].
    fc_w = tf.Variable(tf.random_normal([rnn_size, num_classes]))
    fc_b = tf.Variable(tf.random_normal([num_classes]))
    return tf.matmul(last_output, fc_w) + fc_b


5.构建网络

# Instantiate the network on the input placeholder: `logits` are the raw
# class scores, `prediction` the softmax-normalised class probabilities.
logits = rnn_model(x)
prediction = tf.nn.softmax(logits=logits)


6.定义损失函数与优化器

# Mean softmax cross-entropy between the one-hot labels and the raw logits.
loss_op = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
# The original text called optimizer.minimize() without ever creating the
# optimizer, which raises NameError.
# NOTE(review): Adam is assumed here; the article does not name the optimizer.
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
train_op = optimizer.minimize(loss_op)


7.定义评价指标

# A sample counts as correct when the highest-scoring class matches the
# index of the largest entry in its label vector.
correct_pred = tf.equal(
    tf.argmax(logits, axis=1),
    tf.argmax(y, axis=1),
)
# Accuracy is the mean of the per-sample hits cast to float.
accuracy = tf.reduce_mean(tf.cast(correct_pred, dtype=tf.float32))


8.训练网络

# Op that initialises every tf.Variable created while building the graph.
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    # Number of full mini-batches per epoch; a trailing partial batch is skipped.
    total_batch = mnist.train.num_examples // batch_size
    for epoch in range(train_epochs):
        for batch in range(total_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            feed = {x: batch_x, y: batch_y}
            sess.run(train_op, feed_dict=feed)

            # Every 200 batches, report loss/accuracy on the batch just trained on.
            if batch % 200 == 0:
                loss, acc = sess.run([loss_op, accuracy], feed_dict=feed)
                print("epoch {}, batch {}, loss {:.4f}, accuracy {:.3f}".format(epoch, batch, loss, acc))

    print("optimization finished")

    # Final evaluation on the whole test split; it must run inside the session.
    test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
    print("test acc", test_acc)


github源代码
https://github.com/gamersover/tensorflow_basic_tutorial/blob/master/basic_model/rnn_mnist.py

08-29 107

07-13 3518
06-15 161
09-09 3340
06-20 952
©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客