Related post:
https://blog.csdn.net/huqinweI987/article/details/83155110
Input format: reshape batch_size*784 into batch_size*28*28, i.e. 28 time steps, each step holding the 28 grayscale values of one image row.
The network scans a handwritten digit row by row, summarizes the features of each row, chains them through the time sequence, and reaches a final classification.
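The reshape can be sketched with NumPy (the batch size of 32 below is just an illustrative assumption; any batch size works):

```python
import numpy as np

# a fake batch of flattened MNIST images: batch_size x 784
batch = np.random.rand(32, 784).astype(np.float32)

# reshape to batch_size x 28 x 28: 28 time steps, each a row of 28 pixels
seq = batch.reshape(-1, 28, 28)

print(seq.shape)  # (32, 28, 28)
# time step t of sample 0 is simply row t of the original image
print(np.array_equal(seq[0, 3], batch[0, 3 * 28:4 * 28]))  # True
```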
Network definition: put the construction of a single cell into its own function, so it can be called once per layer inside MultiRNNCell to build a multi-layer LSTM network:
def get_a_cell(i):
    # one LSTM layer; forget_bias=1.0 keeps the forget gate open early in training
    lstm_cell = rnn.BasicLSTMCell(num_units=HIDDEN_CELL, forget_bias=1.0,
                                  state_is_tuple=True, name='layer_%s' % i)
    print(type(lstm_cell))
    # dropout on the cell output only; inputs are kept intact
    dropout_wrapped = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0,
                                         output_keep_prob=keep_prob)
    return dropout_wrapped

multi_lstm = rnn.MultiRNNCell(cells=[get_a_cell(i) for i in range(LSTM_LAYER)],
                              state_is_tuple=True)  # tf.nn.rnn_cell.MultiRNNCell
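The list comprehension matters: each layer must be a distinct cell object, which is why a factory function is used instead of creating one cell and repeating it. A plain-Python sketch of the difference (the `Cell` class here is a stand-in, not a TensorFlow type):

```python
class Cell:
    """Stand-in for an RNN cell (illustration only)."""
    pass

# one fresh object per layer, as the list comprehension above does
distinct = [Cell() for _ in range(3)]
# the buggy alternative: three references to the SAME object
shared = [Cell()] * 3

print(len({id(c) for c in distinct}))  # 3 distinct cells
print(len({id(c) for c in shared}))    # 1 shared cell
```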
With a multi-layer RNN the state differs from the single-layer case in a few details: each layer is a cell with its own state, so the returned state holds one LSTMStateTuple per layer (this example is a classification task and only uses the last layer's output, but other tasks may well need the intermediate layers' states).
The cells are stacked in series; state[-1] is the last layer's state, whose h is equivalent to the single-layer output. Three layers are built here, so indices -1 and 2 refer to the same layer:
outputs, state = tf.nn.dynamic_rnn(multi_lstm, inputs = tf_x_reshaped, initial_state = init_state, time_major = False)
print('state:',state)
print('state[0]:',state[0])#layer 0's LSTMStateTuple
print('state[1]:',state[1])#layer 1's LSTMStateTuple
print('state[2]:',state[2])#layer 2's LSTMStateTuple
print('state[-1]:',state[-1])#layer 2's LSTMStateTuple
state: (LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_3:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_4:0' shape=(32, 256) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_5:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_6:0' shape=(32, 256) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_7:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_8:0' shape=(32, 256) dtype=float32>))
state[0]: LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_3:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_4:0' shape=(32, 256) dtype=float32>)
state[1]: LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_5:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_6:0' shape=(32, 256) dtype=float32>)
state[2]: LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_7:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_8:0' shape=(32, 256) dtype=float32>)
state[-1]: LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_7:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_8:0' shape=(32, 256) dtype=float32>)
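The printed structure above is ordinary Python tuple nesting, so the indexing rules are the usual ones. A minimal mock of the 3-layer state (tensors replaced by strings) shows why state[-1] and state[2] are the same object:

```python
from collections import namedtuple

# mimic what dynamic_rnn returns for a 3-layer MultiRNNCell:
# a tuple with one LSTMStateTuple(c, h) per layer
LSTMStateTuple = namedtuple('LSTMStateTuple', ['c', 'h'])
state = tuple(LSTMStateTuple(c='c%d' % i, h='h%d' % i) for i in range(3))

print(state[-1] is state[2])  # True: -1 indexes the last (third) layer
print(state[-1].h)            # 'h2', the top layer's hidden state
```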
Below is the comparison between outputs and state: the last time step of outputs corresponds to state[2].h. Since this is a classification task (an N-vs-1 model) and time_major is False, dimension 0 is the batch, so the last time-step output is taken with [:, -1, :]. As the run shows, the two are equal element-wise:
outputs, state = tf.nn.dynamic_rnn(multi_lstm, inputs = tf_x_reshaped, initial_state = init_state, time_major = False)
h_state_0 = state[0][1]  # h of layer 0 (index 1 of an LSTMStateTuple is h)
h_state_1 = state[1][1]  # h of layer 1
h_state = state[-1][1]   # h of the last layer
h_state_2 = h_state      # alias: layer 2 is the last of the three layers
_, loss_,outputs_, state_, h_state_0_, h_state_1_, h_state_2_ = \
sess.run([train_op, cross_entropy,outputs, state, h_state_0, h_state_1, h_state_2], {tf_x:x, tf_y:y, keep_prob:1.0})
print('h_state_2_ == outputs_[:,-1,:]:', h_state_2_ == outputs_[:,-1,:])
h_state_2_ == outputs_[:,-1,:]: [[ True True True ... True True True]
[ True True True ... True True True]
[ True True True ... True True True]
...
[ True True True ... True True True]
[ True True True ... True True True]
[ True True True ... True True True]]
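The indexing outputs[:, -1, :] can be illustrated with NumPy, using the same shapes as above (batch 32, 28 time steps, 256 hidden units):

```python
import numpy as np

batch, steps, hidden = 32, 28, 256
outputs = np.random.rand(batch, steps, hidden).astype(np.float32)

# with time_major=False the layout is (batch, time, hidden),
# so [:, -1, :] selects the final time step for every sample
last = outputs[:, -1, :]
print(last.shape)                                    # (32, 256)
print(np.array_equal(last[0], outputs[0, steps - 1]))  # True
```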
Finally, process the output: the LSTM's output dimensionality is tied to its hidden size and cannot be set independently. With 256 hidden units here, a fully-connected projection is needed to map the output down to 10 dimensions for the final digit classification.
# prediction and loss: project the 256-dim hidden state to CLASS_NUM logits
W = tf.Variable(initial_value=tf.truncated_normal([HIDDEN_CELL, CLASS_NUM], stddev=0.1),
                dtype=tf.float32)
print(W)
b = tf.Variable(initial_value=tf.constant(0.1, shape=[CLASS_NUM]), dtype=tf.float32)
predictions = tf.nn.softmax(tf.matmul(h_state, W) + b)
# cross entropy: -sum(y * log(y_hat))
cross_entropy = -tf.reduce_sum(tf_y * tf.log(predictions))
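The loss -sum(y * log(y_hat)) can be checked by hand with NumPy (the batch below is made up: 2 samples, 3 classes). With one-hot labels, only the predicted probability of the true class contributes. As a side note, TF 1.x also offers tf.nn.softmax_cross_entropy_with_logits, which computes the same quantity from raw logits in a numerically safer way than an explicit log(softmax(...)):

```python
import numpy as np

# one-hot labels and softmax outputs for a toy batch
y = np.array([[1., 0., 0.],
              [0., 1., 0.]])
y_hat = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1]])

# -sum(y * log(y_hat)) reduces to -(log 0.7 + log 0.8)
loss = -np.sum(y * np.log(y_hat))
print(round(loss, 4))  # 0.5798
```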
Full code:
https://github.com/huqinwei/tensorflow_demo/blob/master/lstm_mnist/multi_lstm_state_and_output.py