Recognizing MNIST Handwritten Digits with a Multi-Layer LSTM in TensorFlow, and the Relationship Between state and output in Multi-Layer LSTMs

Related content:

https://blog.csdn.net/huqinweI987/article/details/83155110

 

Input format: reshape batch_size×784 into batch_size×28×28, i.e., a sequence of 28 steps, each step holding the 28 grayscale values of one image row.

The network scans a handwritten digit row by row, summarizes the features of each row, chains them through the time dimension, and reaches a final classification.
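The reshape can be sketched in plain NumPy (a batch size of 32 is assumed here for illustration; in the article it happens inside the graph via tf.reshape):

```python
import numpy as np

batch_size = 32  # assumed value for illustration
flat_batch = np.random.rand(batch_size, 784).astype(np.float32)  # flattened MNIST images

# Each 784-pixel image becomes 28 time steps of 28 grayscale values (one image row per step)
seq_batch = flat_batch.reshape(batch_size, 28, 28)

print(seq_batch.shape)  # (32, 28, 28)
# Row 0 of the image is exactly time step 0 of the sequence
print(np.array_equal(seq_batch[0, 0], flat_batch[0, :28]))  # True
```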

Network definition: a separate function builds a single cell, so it can be called from MultiRNNCell to create the multi-layer LSTM network.

def get_a_cell(i):
    # One BasicLSTMCell per layer; naming each layer keeps variable scopes distinct
    lstm_cell = rnn.BasicLSTMCell(num_units=HIDDEN_CELL, forget_bias=1.0,
                                  state_is_tuple=True, name='layer_%s' % i)
    print(type(lstm_cell))
    # Apply dropout only to the cell's output
    dropout_wrapped = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0,
                                         output_keep_prob=keep_prob)
    return dropout_wrapped

multi_lstm = rnn.MultiRNNCell(cells=[get_a_cell(i) for i in range(LSTM_LAYER)],
                              state_is_tuple=True)  # tf.nn.rnn_cell.MultiRNNCell

With a multi-layer RNN, state differs from the single-layer case and has more structure: each layer is a cell, each cell has its own state, and each layer gets its own LSTMStateTuple. (This example is classification, so only the last layer's output is used, but that does not mean other tasks never need the intermediate layers' states.)

The cells are stacked in series. state[-1] is the last layer's state, whose h component is equivalent to the single-layer output. I built three layers here, so state[-1] and state[2] are the same:


outputs, state = tf.nn.dynamic_rnn(multi_lstm, inputs = tf_x_reshaped, initial_state = init_state, time_major = False)
print('state:',state)
print('state[0]:',state[0])#layer 0's LSTMStateTuple
print('state[1]:',state[1])#layer 1's LSTMStateTuple
print('state[2]:',state[2])#layer 2's LSTMStateTuple
print('state[-1]:',state[-1])#layer 2's LSTMStateTuple

state: (LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_3:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_4:0' shape=(32, 256) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_5:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_6:0' shape=(32, 256) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_7:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_8:0' shape=(32, 256) dtype=float32>))
state[0]: LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_3:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_4:0' shape=(32, 256) dtype=float32>)
state[1]: LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_5:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_6:0' shape=(32, 256) dtype=float32>)
state[2]: LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_7:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_8:0' shape=(32, 256) dtype=float32>)
state[-1]: LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_7:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_8:0' shape=(32, 256) dtype=float32>)
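What the printout shows is plain Python tuple indexing over per-layer LSTMStateTuples. The nesting can be mimicked without TensorFlow (a sketch using collections.namedtuple, not the real TF class):

```python
from collections import namedtuple
import numpy as np

LSTMStateTuple = namedtuple('LSTMStateTuple', ['c', 'h'])  # mimics TF's class

LSTM_LAYER, BATCH, HIDDEN_CELL = 3, 32, 256
# MultiRNNCell's state: one LSTMStateTuple(c, h) per layer, collected in a tuple
state = tuple(
    LSTMStateTuple(c=np.zeros((BATCH, HIDDEN_CELL)), h=np.zeros((BATCH, HIDDEN_CELL)))
    for _ in range(LSTM_LAYER)
)

print(len(state))             # 3: one LSTMStateTuple per layer
print(state[-1] is state[2])  # True: with three layers, -1 indexes layer 2
print(state[0].h.shape)       # (32, 256)
```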

Below is a comparison of outputs and state: the last step of outputs corresponds to the h of state[2]. Since this is classification (an N-vs-1 model) and time_major is False, dimension 0 is the batch, so the last output of the time sequence is taken with [:, -1, :]. As the result shows, the two are element-wise equal.


outputs, state = tf.nn.dynamic_rnn(multi_lstm, inputs=tf_x_reshaped,
                                   initial_state=init_state, time_major=False)
h_state_0 = state[0][1]  # h of layer 0's LSTMStateTuple (c, h)
h_state_1 = state[1][1]  # h of layer 1
h_state = state[-1][1]   # h of the last layer
h_state_2 = h_state


        _, loss_, outputs_, state_, h_state_0_, h_state_1_, h_state_2_ = \
            sess.run([train_op, cross_entropy, outputs, state, h_state_0, h_state_1, h_state_2],
                     {tf_x: x, tf_y: y, keep_prob: 1.0})


        print('h_state_2_ == outputs_[:,-1,:]:', h_state_2_ == outputs_[:,-1,:])


h_state_2_ == outputs_[:,-1,:]: [[ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 ...
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]]
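The equality is no coincidence: at every time step an LSTM emits its hidden state h as the output, so the last slice of outputs is exactly the h of the final state. A minimal single-step LSTM loop in NumPy (toy random weights, not the article's trained network) makes this visible:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, c_prev, h_prev, W, b):
    # One BasicLSTMCell-style step: gates computed from concatenated [x, h_prev]
    z = np.concatenate([x, h_prev]) @ W + b
    i, j, f, o = np.split(z, 4)  # input gate, candidate, forget gate, output gate
    c = sigmoid(f + 1.0) * c_prev + sigmoid(i) * np.tanh(j)  # forget_bias = 1.0
    h = sigmoid(o) * np.tanh(c)
    return c, h

rng = np.random.default_rng(0)
num_units, input_dim = 4, 3
W = rng.normal(size=(input_dim + num_units, 4 * num_units))
b = np.zeros(4 * num_units)

c = h = np.zeros(num_units)
outputs = []
for t in range(5):  # a 5-step toy sequence
    c, h = lstm_step(rng.normal(size=input_dim), c, h, W, b)
    outputs.append(h)  # the per-step output IS the hidden state h

print(np.array_equal(outputs[-1], h))  # True: outputs[-1] equals the final state's h
```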

 

Finally, handle the output: for convenience, the LSTM interface emits outputs with the same dimensionality as the hidden state (256 here), and this cannot be set independently, so a linear projection converts it to a 10-dimensional output for the final digit classification.

# prediction and loss
W = tf.Variable(initial_value=tf.truncated_normal([HIDDEN_CELL, CLASS_NUM], stddev=0.1),
                dtype=tf.float32)
print(W)
b = tf.Variable(initial_value=tf.constant(0.1, shape=[CLASS_NUM]), dtype=tf.float32)
predictions = tf.nn.softmax(tf.matmul(h_state, W) + b)
# cross entropy: sum of -y * log(y_hat)
cross_entropy = -tf.reduce_sum(tf_y * tf.log(predictions))
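One caveat with this formulation: tf.log(predictions) yields -inf when a softmax probability underflows to zero. In TF 1.x, tf.nn.softmax_cross_entropy_with_logits_v2 applied to the raw logits avoids this via log-sum-exp. The trick can be sketched in NumPy:

```python
import numpy as np

def naive_cross_entropy(logits, labels):
    # softmax followed by log: overflows/underflows for extreme logits
    e = np.exp(logits)
    probs = e / e.sum(axis=-1, keepdims=True)
    return -(labels * np.log(probs)).sum()

def stable_cross_entropy(logits, labels):
    # log-softmax via the log-sum-exp trick: shift by the max logit first
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -(labels * log_probs).sum()

labels = np.array([[0., 1., 0.]])
print(stable_cross_entropy(np.array([[1., 2., 3.]]), labels))     # matches the naive version
print(stable_cross_entropy(np.array([[1000., 0., 0.]]), labels))  # finite; naive gives inf/nan here
```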

 

Full code:

https://github.com/huqinwei/tensorflow_demo/blob/master/lstm_mnist/multi_lstm_state_and_output.py

 

 
