When building an RNN model with TensorFlow, a few parameters have structures that are hard to understand, and this caused real problems when modifying the network later. After looking through material online, here is a summary of what those parameter structures actually are.
Example code:
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

def Weight_Variable(shape):
    initial = tf.truncated_normal(shape=shape, stddev=0.1)
    return tf.Variable(initial)

def Bias_Variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

def RNN(x, weights, bias, n_times, n_inputs, n_hidden_units):
    # inputs shape: (100, 28, 28)
    inputs = tf.reshape(x, [-1, n_times, n_inputs])
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units)
    # note: TensorFlow removed core_rnn_cell; the old form was
    # lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell
    # output shape: (100, 28, 100); finall_state is a tuple of two
    # elements, each with shape (100, 100)
    output, finall_state = tf.nn.dynamic_rnn(lstm_cell, inputs, dtype=tf.float32)
    # prediction shape: (100, 10); finall_state[1] is the last step's h state
    prediction = tf.matmul(finall_state[1], weights) + bias
    return prediction

def main():
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    batch_size = 100
    n_batch = mnist.train.num_examples // batch_size
    n_inputs = 28
    n_times = 28
    n_hidden_units = 100
    n_classes = 10
    x = tf.placeholder(tf.float32, [None, n_inputs * n_times])
    y = tf.placeholder(tf.float32, [None, n_classes])
    weights = Weight_Variable([n_hidden_units, n_classes])
    bias = Bias_Variable([n_classes])
    prediction = RNN(x, weights, bias, n_times, n_inputs, n_hidden_units)
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
    optimizer = tf.train.AdamOptimizer(1e-4)
    train = optimizer.minimize(cross_entropy)
    correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(11):
            for batch in range(n_batch):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                # batch_xs.shape = (batch_size, n_inputs * n_times)
                sess.run(train, feed_dict={x: batch_xs, y: batch_ys})
            acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
            print('Iter ' + str(epoch + 1) + ', testing accuracy = ' + str(acc))

if __name__ == '__main__':
    main()
Taking the MNIST handwritten-digit dataset as an example: each image in the original dataset has size [28, 28]. With batch_size set to 100, the raw input is [100, 28*28] = [100, 784], which is reshaped into [100, 28, 28] so the model can process it as 28 time steps of 28-dimensional inputs per image.
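The reshape step can be sketched in NumPy (using a dummy batch in place of real MNIST data): each 784-value image becomes 28 rows, and each row is what the RNN receives as one 28-dimensional time-step input.

```python
import numpy as np

# Dummy batch standing in for 100 flattened MNIST images (28*28 = 784 values).
batch = np.arange(100 * 28 * 28).reshape(100, 28 * 28)

# Same operation as tf.reshape(x, [-1, n_times, n_inputs]):
# (batch, time steps, inputs per step).
sequences = batch.reshape(-1, 28, 28)

print(sequences.shape)                     # (100, 28, 28)
# Row 1 of image 0 becomes the input at time step 1:
print(sequences[0, 1, 0] == batch[0, 28])  # True
```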
The key point is the meaning of n_hidden_units in tf.contrib.rnn.BasicLSTMCell(n_hidden_units): after some digging, it turns out to be the dimensionality of the vector the cell outputs. In this RNN we feed in a 28-dimensional input at each step and get a 100-dimensional output at each step. output collects the result of every step, so its final shape is [100, 28, 100]. finall_state is a tuple of two states (c_state, h_state), each containing only the last step's data, so both have shape [100, 100]; h_state is identical to the last row of output (i.e. the last step) for each example in the batch.
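These shape relationships can be reproduced with a minimal hand-rolled LSTM forward pass in NumPy. This is only a sketch with random weights, not TensorFlow's actual BasicLSTMCell implementation, but the shape bookkeeping is the same: output stacks every step's h, while the final (c, h) pair holds only the last step, and h equals the last slice of output.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(x, n_hidden, rng):
    """Minimal LSTM forward pass; x has shape (batch, steps, inputs)."""
    batch, steps, n_in = x.shape
    # One combined weight matrix for the input, forget, cell and output gates.
    W = rng.standard_normal((n_in + n_hidden, 4 * n_hidden)) * 0.1
    b = np.zeros(4 * n_hidden)
    c = np.zeros((batch, n_hidden))
    h = np.zeros((batch, n_hidden))
    outputs = []
    for t in range(steps):
        z = np.concatenate([x[:, t, :], h], axis=1) @ W + b
        i, f, g, o = np.split(z, 4, axis=1)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)
    # outputs holds every step's h; (c, h) holds only the last step's states.
    return np.stack(outputs, axis=1), (c, h)

rng = np.random.default_rng(0)
x = rng.standard_normal((100, 28, 28))          # (batch, steps, inputs)
output, (c_state, h_state) = lstm_forward(x, 100, rng)

print(output.shape)                              # (100, 28, 100)
print(c_state.shape, h_state.shape)              # (100, 100) (100, 100)
print(np.allclose(output[:, -1, :], h_state))    # True
```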
Reference: https://xdrush.github.io/2018/02/12/RNN原理详解以及tensorflow中的RNN实现/