Analyzing the outputs of the bidirectional variable-length RNN function tf.nn.bidirectional_dynamic_rnn

tf.nn.bidirectional_dynamic_rnn

To avoid the harm that blind padding does to model quality, tf.nn.bidirectional_dynamic_rnn supports variable-length inputs and only runs the computation over each sequence's valid length. The inputs themselves, however, still have to be padded to a common length before being fed in.
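For instance, here is a minimal padding sketch in plain NumPy (pad_batch is just an illustrative helper, not part of the TF API) that zero-pads a batch to its longest sequence and records the true lengths to pass as sequence_length:

import numpy as np

# Hypothetical helper: pad a list of [seq_len, feat_dim] sequences with zero
# rows up to the longest sequence, and record the original lengths.
def pad_batch(sequences):
    max_len = max(len(s) for s in sequences)
    feat_dim = len(sequences[0][0])
    padded = np.zeros((len(sequences), max_len, feat_dim), dtype=np.float32)
    lengths = []
    for i, seq in enumerate(sequences):
        padded[i, :len(seq), :] = seq
        lengths.append(len(seq))
    return padded, lengths

seqs = [[[1, 2, 3], [2, 3, 4]],             # length 2
        [[7, 8, 9], [2, 5, 7], [8, 5, 3]],  # length 3
        [[4, 6, 7]]]                        # length 1
x_padded, seq_lens = pad_batch(seqs)        # shape (3, 3, 3), seq_lens == [2, 3, 1]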


outputs, final_state = tf.nn.bidirectional_dynamic_rnn(lstm1, lstm2, x, sequence_length=[2,3,1,2], dtype=tf.float32)
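The return value is a pair of tuples: outputs is (output_fw, output_bw), one tensor of shape [batch_size, max_time, num_units] per direction, and final_state is (state_fw, state_bw), where each element is an LSTMStateTuple(c, h) when LSTM cells are used. A minimal sketch of unpacking them (the variable names are illustrative):

# outputs is a (forward, backward) pair; each tensor has shape
# [batch_size, max_time, num_units]. final_state is a (forward, backward)
# pair of LSTMStateTuple(c, h), each component shaped [batch_size, num_units].
output_fw, output_bw = outputs
state_fw, state_bw = final_state

# A common pattern: concatenate the two directions along the feature axis,
# giving a [batch_size, max_time, 2 * num_units] tensor.
outputs_concat = tf.concat([output_fw, output_bw], axis=-1)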

The following example shows the function's outputs concretely.

import tensorflow as tf

# A padded batch of 4 sequences, each padded to max_time = 3 with zero rows.
# The true lengths are [2, 3, 1, 2].
x = [[[1, 2, 3], [2, 3, 4], [0, 0, 0]],
     [[7, 8, 9], [2, 5, 7], [8, 5, 3]],
     [[4, 6, 7], [0, 0, 0], [0, 0, 0]],
     [[2, 6, 9], [7, 4, 7], [0, 0, 0]]]

x = tf.convert_to_tensor(x, dtype=tf.float32)
print(x.shape)  # (4, 3, 3): [batch_size, max_time, input_dim]

# Forward and backward LSTM cells, each with 5 hidden units.
lstm1 = tf.contrib.rnn.LSTMCell(5, state_is_tuple=True)
lstm2 = tf.contrib.rnn.LSTMCell(5, state_is_tuple=True)

outputs, final_state = tf.nn.bidirectional_dynamic_rnn(
    lstm1, lstm2, x, sequence_length=[2, 3, 1, 2], dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    o = sess.run(outputs)
    f = sess.run(final_state)
    print(o)
    print('======================')
    print(f)

Output:

(array([[[-4.7971986e-02,  2.6565358e-01, -5.0791793e-02, -6.3681910e-03,
         -1.4741863e-01],
        [-5.7430852e-02,  3.9343110e-01, -1.5554780e-01, -1.6064603e-02,
         -1.1796383e-01],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
          0.0000000e+00]],

       [[-8.2660372e-05,  4.5020923e-01, -2.0235220e-01, -4.5891651e-03,
         -7.0640403e-03],
        [-3.2798066e-03,  6.6236269e-01, -3.0268517e-01, -8.6834300e-03,
         -5.9959032e-02],
        [-1.1065261e-02,  7.9541743e-02, -2.5284705e-01, -1.5045789e-01,
         -1.4271882e-02]],

       [[-1.1178922e-03,  4.8776790e-01, -1.5527356e-01, -4.9713873e-03,
         -3.1170087e-02],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
          0.0000000e+00],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
          0.0000000e+00]],

       [[-8.7980053e-04,  6.6924572e-01, -5.8024418e-02,  1.0198562e-03,
         -3.5060328e-02],
        [-3.5331668e-03,  4.2750126e-01, -2.8861091e-01, -4.6727475e-02,
         -8.3969785e-03],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
          0.0000000e+00]]], dtype=float32), array([[[-0.46014848, -0.18694998, -0.33176026, -0.36103424,
          0.12055817],
        [-0.35464838, -0.10751288, -0.36990282, -0.21526965,
          0.05854482],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ]],

       [[-0.72168607,  0.04445596, -0.6784164 , -0.2931869 ,
          0.00137328],
        [-0.42270145,  0.12505652, -0.5798562 , -0.33974066,
          0.01590933],
        [-0.6525634 ,  0.21627644, -0.47753727, -0.18012388,
          0.01027563]],

       [[-0.37131876, -0.05870048, -0.580433  , -0.20945168,
          0.00916489],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ]],

       [[-0.36822024, -0.10755689, -0.62205905, -0.3472318 ,
          0.00997364],
        [-0.67421496, -0.08016963, -0.5534432 , -0.3275094 ,
          0.01982484],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ]]], dtype=float32))
======================

As the printout shows, outputs holds the forward and backward LSTM results separately, and both respect sequence_length: for a sequence of valid length 2, only the first two rows of its output are filled in, while the remaining all-zero row corresponds to a padded time step that did not take part in the computation.
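A quick NumPy check on the fetched o (the two arrays printed above) confirms that every time step beyond a sequence's valid length is exactly zero in both directions:

import numpy as np

output_fw, output_bw = o  # forward and backward outputs, as printed above
for i, length in enumerate([2, 3, 1, 2]):
    # All rows past the valid length are zero in both directions.
    assert np.all(output_fw[i, length:] == 0.0)
    assert np.all(output_bw[i, length:] == 0.0)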

(LSTMStateTuple(c=array([[-0.23832373,  0.7580961 , -0.2514931 , -0.15673895, -0.9515827 ],
       [-0.03692751,  1.2597698 , -0.805684  , -1.3808743 , -0.44871476],
       [-0.01143863,  0.85898864, -0.2325824 , -0.2767025 , -0.9141735 ],
       [-0.01752725,  1.1746776 , -0.44077516, -0.7668716 , -0.7264488 ]],
      dtype=float32), h=array([[-0.05743085,  0.3934311 , -0.1555478 , -0.0160646 , -0.11796383],
       [-0.01106526,  0.07954174, -0.25284705, -0.15045789, -0.01427188],
       [-0.00111789,  0.4877679 , -0.15527356, -0.00497139, -0.03117009],
       [-0.00353317,  0.42750126, -0.2886109 , -0.04672747, -0.00839698]],
      dtype=float32)), LSTMStateTuple(c=array([[-1.3982543 , -0.22885469, -0.53742784, -1.0023961 ,  0.35004577],
       [-2.755764  ,  0.04452498, -0.96565115, -2.3552978 ,  0.01877723],
       [-0.98055553, -0.05923914, -0.86219144, -0.8616287 ,  0.07237855],
       [-1.8098019 , -0.10917448, -0.8589754 , -1.8528795 ,  0.07483704]],
      dtype=float32), h=array([[-0.46014848, -0.18694998, -0.33176026, -0.36103424,  0.12055817],
       [-0.72168607,  0.04445596, -0.6784164 , -0.2931869 ,  0.00137328],
       [-0.37131876, -0.05870048, -0.580433  , -0.20945168,  0.00916489],
       [-0.36822024, -0.10755689, -0.62205905, -0.3472318 ,  0.00997364]],
      dtype=float32)))

The final state thus contains, for each direction, the LSTM cell state c and hidden state h (one row per batch element).
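Comparing the two printouts, each forward h row equals the forward output at that sequence's last valid time step, and each backward h row equals the backward output at time step 0 (the backward outputs are re-reversed to align with the original time axis). A small NumPy check over the fetched o and f values, using the sequence lengths from the example:

import numpy as np

output_fw, output_bw = o   # fetched outputs, as printed above
state_fw, state_bw = f     # fetched final states, as printed above
seq_lens = [2, 3, 1, 2]

for i, length in enumerate(seq_lens):
    # Forward: the final hidden state is the output at the last valid step.
    assert np.allclose(output_fw[i, length - 1], state_fw.h[i], atol=1e-5)
    # Backward: the final hidden state is the output at time step 0.
    assert np.allclose(output_bw[i, 0], state_bw.h[i], atol=1e-5)

To build a fixed-size representation of each sequence, the forward and backward hidden states are commonly concatenated, e.g. tf.concat([state_fw.h, state_bw.h], axis=1), giving a [batch_size, 2 * num_units] tensor.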
