Understanding the sequence_length parameter in TensorFlow's stack_bidirectional_dynamic_rnn and bidirectional_dynamic_rnn


While working on my graduation project, my code used the stack_bidirectional_dynamic_rnn API. I agonized for two or three days over its sequence_length parameter, unsure whether it was actually necessary to pass it. Yesterday I read the relevant source code and got a rough picture.
Reading the source shows that internally this function calls bidirectional_dynamic_rnn in a for loop. Here is the relevant source (link):

def stack_bidirectional_dynamic_rnn(cells_fw,
                                    cells_bw,
                                    inputs,
                                    initial_states_fw=None,
                                    initial_states_bw=None,
                                    dtype=None,
                                    sequence_length=None,
                                    parallel_iterations=None,
                                    time_major=False,
                                    scope=None):
  """
  ...
  Args:
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences.
      ...
"""
 ...
  states_fw = []
  states_bw = []
  prev_layer = inputs

  with vs.variable_scope(scope or "stack_bidirectional_rnn"):
    for i, (cell_fw, cell_bw) in enumerate(zip(cells_fw, cells_bw)):
      initial_state_fw = None
      initial_state_bw = None
      if initial_states_fw:
        initial_state_fw = initial_states_fw[i]
      if initial_states_bw:
        initial_state_bw = initial_states_bw[i]

      with vs.variable_scope("cell_%d" % i):
        outputs, (state_fw, state_bw) = rnn.bidirectional_dynamic_rnn(
            cell_fw,
            cell_bw,
            prev_layer,
            initial_state_fw=initial_state_fw,
            initial_state_bw=initial_state_bw,
            sequence_length=sequence_length,
            parallel_iterations=parallel_iterations,
            dtype=dtype,
            time_major=time_major)
        # Concat the outputs to create the new input.
        prev_layer = array_ops.concat(outputs, 2)
      states_fw.append(state_fw)
      states_bw.append(state_bw)

  return prev_layer, tuple(states_fw), tuple(states_bw)

As you can see, the docstring marks sequence_length as optional. But when actually debugging, I found that if sequence_length is not passed, then in the output (of shape [max_time, layers_output] per example), the time steps beyond an example's actual length hold different, nonzero values: the computation keeps propagating forward through the padding. During model training this could affect the final prediction.
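As an aside (my own addition, not from the original post): if the batch is zero-padded, a common way to derive the sequence_length vector directly from the inputs is to count the time steps that contain any nonzero feature. A minimal sketch, assuming no real time step is entirely zero; the helper name compute_lengths is mine:

import tensorflow as tf

def compute_lengths(inputs):
    # inputs: [batch_size, max_time, feature_dim], zero-padded.
    # A time step counts as real if any of its features is nonzero.
    used = tf.sign(tf.reduce_max(tf.abs(inputs), axis=2))   # [batch, time], 0/1
    return tf.cast(tf.reduce_sum(used, axis=1), tf.int32)   # [batch]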

Now let's look at how bidirectional_dynamic_rnn (link) handles the sequence_length parameter. Note that internally this function calls dynamic_rnn.

def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None,
                              swap_memory=False, time_major=False, scope=None):
  ...
  Args:
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences in the batch.
      If not provided, all batch entries are assumed to be full sequences; and
      time reversal is applied from time `0` to `max_time` for each sequence.
...
  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
      output_fw, output_state_fw = dynamic_rnn(
          cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
          initial_state=initial_state_fw, dtype=dtype,
          parallel_iterations=parallel_iterations, swap_memory=swap_memory,
          time_major=time_major, scope=fw_scope)

In the forward direction, sequence_length is passed straight into dynamic_rnn.

Now comes the key part: in the backward direction, the input must first be reversed, and only then is dynamic_rnn run on it.

    # Backward direction
    if not time_major:
      time_axis = 1
      batch_axis = 0
    else:
      time_axis = 0
      batch_axis = 1

    def _reverse(input_, seq_lengths, seq_axis, batch_axis):
      if seq_lengths is not None:
        return array_ops.reverse_sequence(
            input=input_, seq_lengths=seq_lengths,
            seq_axis=seq_axis, batch_axis=batch_axis)
      else:
        return array_ops.reverse(input_, axis=[seq_axis])

    with vs.variable_scope("bw") as bw_scope:

      def _map_reverse(inp):
        return _reverse(
            inp,
            seq_lengths=sequence_length,
            seq_axis=time_axis,
            batch_axis=batch_axis)

      inputs_reverse = nest.map_structure(_map_reverse, inputs)
      tmp, output_state_bw = dynamic_rnn(
          cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
          initial_state=initial_state_bw, dtype=dtype,
          parallel_iterations=parallel_iterations, swap_memory=swap_memory,
          time_major=time_major, scope=bw_scope)

In the backward direction, the input is reversed via the _map_reverse function, which is really the _reverse function, and this is exactly where sequence_length is handled. If the parameter is not supplied, the whole sequence is flipped with array_ops.reverse, so for variable-length inputs the padded zeros are what get fed into the backward pass first. If sequence_length is supplied, the input is flipped with array_ops.reverse_sequence instead.

First, the difference between the two reversal functions:

import tensorflow as tf
import numpy as np

with tf.Session() as sess:
    # a has shape [3, 4, 2]: batch of 3 sequences, max_time 4, 2 features
    a = np.array([[[1, 2], [2, 1], [4, 3], [0, 0]],
                  [[2, 1], [3, 4], [0, 0], [0, 0]],
                  [[3, 5], [1, 3], [4, 6], [5, 2]]])

    seq_length = [3, 2, 4]  # actual length of each sequence
    b = tf.reverse_sequence(a, seq_length, 1, 0)  # seq_axis=1, batch_axis=0
    c = tf.reverse(a, axis=[1])
    print(sess.run(b))
    print(sess.run(c))
    
Output:
[[[4 3]
  [2 1]
  [1 2]
  [0 0]]
 [[3 4]
  [2 1]
  [0 0]
  [0 0]]
 [[5 2]
  [4 6]
  [1 3]
  [3 5]]]
[[[0 0]
  [4 3]
  [2 1]
  [1 2]]
 [[0 0]
  [0 0]
  [3 4]
  [2 1]]
 [[5 2]
  [4 6]
  [1 3]
  [3 5]]]

You can see that reverse_sequence flips only the actual-length portion, leaving the zeros at the end, while reverse simply flips everything.
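To make the semantics concrete, here is a minimal pure-NumPy sketch (my own addition, not from the original post) of what reverse_sequence does in the batch-major case:

import numpy as np

def reverse_sequence_np(x, seq_lengths):
    # batch-major equivalent of tf.reverse_sequence(x, seq_lengths, 1, 0):
    # reverse only the first seq_lengths[i] steps of each example,
    # leaving the padding where it is.
    out = x.copy()
    for i, n in enumerate(seq_lengths):
        out[i, :n] = x[i, :n][::-1]
    return out

Running reverse_sequence_np(a, [3, 2, 4]) on the array above reproduces the first printout.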

So, the first conclusion: the difference in how stack_bidirectional_dynamic_rnn / bidirectional_dynamic_rnn handle sequence_length comes down to which of the two reversal functions is used, reverse_sequence or reverse.

So what effect does passing (or omitting) sequence_length actually have on the output of stack_bidirectional_dynamic_rnn and bidirectional_dynamic_rnn?

As I see it, there are two ways the output is potentially affected (a masking workaround is sketched after this list):

  • First, in the forward pass (and its gradient computation), for an example shorter than max_time, computation continues past its actual length through the zero-padded time steps, producing unnecessary, spurious values and lengthening training time;

  • Second, since this is a bidirectional RNN, in the backward pass after the reverse, the computation on the leading zero-padded time steps is wasted; real work only begins once actual data arrives.
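If for some reason sequence_length is not passed, one post-hoc fix (my own hedged sketch, not from the original post; here outputs stands for an RNN output tensor of shape [batch, max_time, num_units] and X_lengths for the length vector) is to zero out the padded steps with tf.sequence_mask:

# zero out time steps beyond each example's real length
mask = tf.sequence_mask(X_lengths, maxlen=tf.shape(outputs)[1],
                        dtype=outputs.dtype)           # [batch, max_time]
masked_outputs = outputs * tf.expand_dims(mask, -1)    # broadcast over units

Note this only cleans up the outputs; it does not stop the backward direction from consuming the padding first, which only sequence_length can do.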

What still needs to be settled is whether passing sequence_length affects the accuracy of the computed results.

Starting from the basics, first look at how sequence_length affects the output of dynamic_rnn:

with tf.Session() as sess:
    X = np.random.randn(2, 10, 8)
    # the second example has an actual length of 6
    X[1, 6:] = 0.0
    # X[1, :6] = 0.0
    X_lengths = [10, 6]

    cell = tf.nn.rnn_cell.LSTMCell(num_units=5)
    outputs, last_states = tf.nn.dynamic_rnn(
        cell=cell,
        dtype=tf.float64,
        sequence_length=X_lengths,
        inputs=X)
    outputs1, last_states1 = tf.nn.dynamic_rnn(
        cell=cell,
        dtype=tf.float64,
        inputs=X)
    sess.run(tf.global_variables_initializer())
    print(sess.run(outputs))
    print(sess.run(outputs1))

Output:
[[[ 0.3664112  -0.16582295  0.03347532 -0.1209446  -0.16412497]
  [ 0.35257143 -0.05686     0.03231346 -0.13748774 -0.19436866]
  [ 0.53990554 -0.0629166  -0.01856082 -0.08934467 -0.19854734]
  [ 0.62018885 -0.14675617 -0.23786601 -0.18285464 -0.30012293]
  [ 0.29875151 -0.07205287 -0.09079972 -0.32260035 -0.07121483]
  [ 0.12284296  0.05494323 -0.26602965 -0.31480686 -0.00965394]
  [-0.04321202  0.10846491 -0.15792417  0.04715865  0.16775026]
  [ 0.14991799  0.24498104 -0.22123176 -0.16537638  0.17049721]
  [ 0.22538293  0.21274618 -0.30151976 -0.2806619   0.07987886]
  [ 0.16169614 -0.09109676 -0.23507253 -0.2256854   0.08532138]]
 [[ 0.03130918 -0.00153811  0.05989608  0.00258946  0.12982402]
  [-0.02059972  0.24180047 -0.01174238  0.20337912  0.11303384]
  [ 0.04525809  0.25172152 -0.12350006  0.09227578  0.25866815]
  [ 0.13557124  0.3672627  -0.3223366  -0.05648763  0.20540585]
  [ 0.18130031  0.20992908 -0.09175851  0.07085576  0.23343628]
  [ 0.34051405  0.28743254 -0.10791713 -0.09506878  0.12702959]
  [ 0.          0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.        ]]]
[[[ 0.3664112  -0.16582295  0.03347532 -0.1209446  -0.16412497]
  [ 0.35257143 -0.05686     0.03231346 -0.13748774 -0.19436866]
  [ 0.53990554 -0.0629166  -0.01856082 -0.08934467 -0.19854734]
  [ 0.62018885 -0.14675617 -0.23786601 -0.18285464 -0.30012293]
  [ 0.29875151 -0.07205287 -0.09079972 -0.32260035 -0.07121483]
  [ 0.12284296  0.05494323 -0.26602965 -0.31480686 -0.00965394]
  [-0.04321202  0.10846491 -0.15792417  0.04715865  0.16775026]
  [ 0.14991799  0.24498104 -0.22123176 -0.16537638  0.17049721]
  [ 0.22538293  0.21274618 -0.30151976 -0.2806619   0.07987886]
  [ 0.16169614 -0.09109676 -0.23507253 -0.2256854   0.08532138]]
 [[ 0.03130918 -0.00153811  0.05989608  0.00258946  0.12982402]
  [-0.02059972  0.24180047 -0.01174238  0.20337912  0.11303384]
  [ 0.04525809  0.25172152 -0.12350006  0.09227578  0.25866815]
  [ 0.13557124  0.3672627  -0.3223366  -0.05648763  0.20540585]
  [ 0.18130031  0.20992908 -0.09175851  0.07085576  0.23343628]
  [ 0.34051405  0.28743254 -0.10791713 -0.09506878  0.12702959]
  [ 0.20726339  0.2773752  -0.18121164 -0.06436528  0.1766949 ]
  [ 0.19274723  0.24523639 -0.14879227 -0.05048222  0.20608781]
  [ 0.17663077  0.22035726 -0.12688077 -0.03431938  0.22335574]
  [ 0.16139967  0.19805899 -0.11295674 -0.01820427  0.23194738]]]

Comparing the two, when sequence_length is given, the time steps after the actual length are no longer computed.
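A quick extra check (my own addition, run inside the same with-block as the dynamic_rnn code above): with sequence_length, dynamic_rnn also freezes the returned final state at each example's last real step, so last_states.h for the length-6 example should equal its output at time step index 5.

out_np, states_np = sess.run([outputs, last_states])
print(np.allclose(states_np.h[1], out_np[1, 5]))  # expected: True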

To mimic flipping the whole sequence, zero out the leading time steps instead (here the first 6 steps of the second example): dynamic_rnn's outputs at those steps all come out exactly zero. (With this freshly initialized LSTMCell, whose biases are zero, a zero input and zero state give c = 0 and h = 0, so the state stays at zero through the padding and the outputs are exactly zero, not merely small.)

    X[1, :6] = 0.0
    outputs1, last_states1 = tf.nn.dynamic_rnn(
        cell=cell,
        dtype=tf.float64,
        inputs=X)
    sess.run(tf.global_variables_initializer())
    print(sess.run(outputs1))

Output:
[[[ 0.04069017 -0.23771831  0.045805    0.09560217  0.05502168]
  [ 0.09332863 -0.04726055  0.11690663  0.08879598  0.25465747]
  [ 0.19825957 -0.28982961  0.08313738  0.20992825  0.02033359]
  [ 0.15471266 -0.56485913  0.05661117  0.20066938 -0.03992696]
  [ 0.45740853 -0.0688219   0.34996705  0.07094942 -0.12920683]
  [ 0.45969932  0.11180939 -0.07534739 -0.03005722 -0.02074522]
  [ 0.25090551  0.03269829 -0.25112071 -0.28449897 -0.03996906]
  [ 0.00671926  0.04019763 -0.01132277 -0.06963076  0.00314332]
  [ 0.02276943  0.24375711 -0.18096459  0.10516443 -0.0245148 ]
  [-0.07252719  0.21662108 -0.06866634  0.09010588  0.04821544]]
 [[ 0.          0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.        ]
  [ 0.15381835  0.05551223  0.10847534  0.11666546 -0.10370024]
  [ 0.15700091 -0.1763513   0.16971487  0.0940137  -0.08755241]
  [-0.00314326 -0.16764532  0.13892345  0.05829588 -0.2285242 ]
  [ 0.16221322 -0.24974877 -0.17015645 -0.08633168 -0.05848374]]]

Now look at the output of the bidirectional_dynamic_rnn function:

with tf.Session() as sess:
    cell = tf.nn.rnn_cell.LSTMCell(num_units=5)
    # note: the same cell object is passed for both directions below, so
    # the forward and backward passes share one set of weights here
    outputs, last_states = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=cell,
        cell_bw=cell,
        dtype=tf.float64,
        sequence_length=X_lengths,
        inputs=X)
    outputs1, last_states1 = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=cell,
        cell_bw=cell,
        dtype=tf.float64,
        inputs=X)
    sess.run(tf.global_variables_initializer())
    print(sess.run(outputs))
    print(sess.run(outputs1))

Output:
(array([[[ 1.03637534e-01, -7.10830000e-02, -6.15507404e-02,
          1.82409063e-01, -1.27943658e-02],
        [-1.71373530e-01, -1.63462122e-01,  5.82945984e-02,
          2.30467173e-01, -5.47676139e-02],
        [ 5.13349422e-02, -1.82497689e-01, -1.75145553e-01,
          1.53907335e-01, -7.97362981e-02],
        [ 1.06949741e-01, -4.01306522e-02, -2.10730327e-01,
          1.09050034e-01, -2.32756766e-01],
        [ 4.60590349e-02, -4.58901677e-02, -1.58516232e-01,
          2.01659355e-02,  2.69986613e-02],
        [ 1.34656088e-01,  1.27036696e-01, -1.17810220e-02,
         -3.33886708e-02,  4.41047476e-01],
        [ 2.22404726e-01,  1.59448161e-01, -1.01579624e-01,
         -5.48369031e-02,  1.06230121e-01],
        [ 2.30960285e-01, -3.90051280e-02, -2.24918456e-02,
         -2.35480280e-02,  1.52811464e-02],
        [ 2.10968050e-01,  1.83505109e-01,  1.78626573e-01,
         -2.55444308e-01,  4.00898776e-01],
        [ 2.58521311e-01,  3.29425438e-01,  1.17003287e-01,
         -2.77698999e-02,  1.49407274e-01]],
       [[-1.58240424e-01, -8.11458422e-02,  4.00896978e-02,
          1.04382802e-01, -3.94471746e-03],
        [-3.35112999e-01, -9.55060361e-02,  2.90027134e-01,
         -1.81891039e-01,  9.10953666e-03],
        [-1.67419074e-01, -4.00547833e-04,  1.95651535e-01,
         -2.47421707e-01,  5.14707873e-02],
        [-3.55542561e-02,  1.83676786e-01,  1.79457375e-01,
         -4.30804504e-01,  8.66083426e-02],
        [-2.09451275e-01,  2.34769980e-01,  3.52470628e-01,
         -1.28957238e-01,  2.21493636e-01],
        [-3.41874621e-01, -2.51958521e-02,  5.17341140e-01,
         -3.61904359e-01,  2.32170058e-01],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00]]]), array([[[ 0.16441463, -0.19302047, -0.09601584,  0.26801333,
         -0.03560607],
        [ 0.17595892, -0.15906587, -0.04336751,  0.20887378,
         -0.03159793],
        [ 0.2235174 , -0.05801825, -0.17605939,  0.091533  ,
          0.0131942 ],
        [ 0.11473129,  0.09345286,  0.01537781, -0.01188187,
          0.14568516],
        [ 0.03106931,  0.08806389,  0.40831446, -0.10394683,
          0.36716515],
        [ 0.2329871 ,  0.15022   ,  0.11296073, -0.07147051,
          0.4857348 ],
        [ 0.20390796,  0.01588418,  0.03380659, -0.03090537,
          0.09930256],
        [ 0.24167838, -0.00490647,  0.16429447, -0.00820602,
          0.08831998],
        [ 0.23332149,  0.18020251,  0.12465346, -0.1750562 ,
          0.40673199],
        [ 0.15236341, -0.0623327 , -0.07389674,  0.0791574 ,
          0.02982562]],
       [[-0.36612451, -0.09818593,  0.37779096, -0.10673546,
          0.09468116],
        [-0.24289844, -0.00789502,  0.43728569, -0.36032157,
          0.18526209],
        [-0.05659099,  0.23216722,  0.28674557, -0.31059676,
          0.15522032],
        [-0.22365649,  0.18580737,  0.2019476 , -0.41131144,
          0.19063189],
        [-0.38062693,  0.02047546,  0.36591703, -0.11403186,
          0.28162637],
        [-0.22493386, -0.06232192,  0.41815278, -0.2991254 ,
          0.1451596 ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ]]]))
(array([[[ 1.03637534e-01, -7.10830000e-02, -6.15507404e-02,
          1.82409063e-01, -1.27943658e-02],
        [-1.71373530e-01, -1.63462122e-01,  5.82945984e-02,
          2.30467173e-01, -5.47676139e-02],
        [ 5.13349422e-02, -1.82497689e-01, -1.75145553e-01,
          1.53907335e-01, -7.97362981e-02],
        [ 1.06949741e-01, -4.01306522e-02, -2.10730327e-01,
          1.09050034e-01, -2.32756766e-01],
        [ 4.60590349e-02, -4.58901677e-02, -1.58516232e-01,
          2.01659355e-02,  2.69986613e-02],
        [ 1.34656088e-01,  1.27036696e-01, -1.17810220e-02,
         -3.33886708e-02,  4.41047476e-01],
        [ 2.22404726e-01,  1.59448161e-01, -1.01579624e-01,
         -5.48369031e-02,  1.06230121e-01],
        [ 2.30960285e-01, -3.90051280e-02, -2.24918456e-02,
         -2.35480280e-02,  1.52811464e-02],
        [ 2.10968050e-01,  1.83505109e-01,  1.78626573e-01,
         -2.55444308e-01,  4.00898776e-01],
        [ 2.58521311e-01,  3.29425438e-01,  1.17003287e-01,
         -2.77698999e-02,  1.49407274e-01]],
       [[-1.58240424e-01, -8.11458422e-02,  4.00896978e-02,
          1.04382802e-01, -3.94471746e-03],
        [-3.35112999e-01, -9.55060361e-02,  2.90027134e-01,
         -1.81891039e-01,  9.10953666e-03],
        [-1.67419074e-01, -4.00547833e-04,  1.95651535e-01,
         -2.47421707e-01,  5.14707873e-02],
        [-3.55542561e-02,  1.83676786e-01,  1.79457375e-01,
         -4.30804504e-01,  8.66083426e-02],
        [-2.09451275e-01,  2.34769980e-01,  3.52470628e-01,
         -1.28957238e-01,  2.21493636e-01],
        [-3.41874621e-01, -2.51958521e-02,  5.17341140e-01,
         -3.61904359e-01,  2.32170058e-01],
        [-2.51546442e-01, -1.18885843e-02,  2.96923579e-01,
         -1.86115662e-01,  2.16003435e-01],
        [-2.14802264e-01,  7.02766447e-04,  2.61682611e-01,
         -9.91503917e-02,  1.88783995e-01],
        [-1.87142315e-01,  1.90646907e-03,  2.19077987e-01,
         -3.69484125e-02,  1.55271455e-01],
        [-1.63632591e-01, -6.82563058e-04,  1.78274888e-01,
          2.33492655e-03,  1.24669109e-01]]]), array([[[ 0.16441463, -0.19302047, -0.09601584,  0.26801333,
         -0.03560607],
        [ 0.17595892, -0.15906587, -0.04336751,  0.20887378,
         -0.03159793],
        [ 0.2235174 , -0.05801825, -0.17605939,  0.091533  ,
          0.0131942 ],
        [ 0.11473129,  0.09345286,  0.01537781, -0.01188187,
          0.14568516],
        [ 0.03106931,  0.08806389,  0.40831446, -0.10394683,
          0.36716515],
        [ 0.2329871 ,  0.15022   ,  0.11296073, -0.07147051,
          0.4857348 ],
        [ 0.20390796,  0.01588418,  0.03380659, -0.03090537,
          0.09930256],
        [ 0.24167838, -0.00490647,  0.16429447, -0.00820602,
          0.08831998],
        [ 0.23332149,  0.18020251,  0.12465346, -0.1750562 ,
          0.40673199],
        [ 0.15236341, -0.0623327 , -0.07389674,  0.0791574 ,
          0.02982562]],
       [[-0.36612451, -0.09818593,  0.37779096, -0.10673546,
          0.09468116],
        [-0.24289844, -0.00789502,  0.43728569, -0.36032157,
          0.18526209],
        [-0.05659099,  0.23216722,  0.28674557, -0.31059676,
          0.15522032],
        [-0.22365649,  0.18580737,  0.2019476 , -0.41131144,
          0.19063189],
        [-0.38062693,  0.02047546,  0.36591703, -0.11403186,
          0.28162637],
        [-0.22493386, -0.06232192,  0.41815278, -0.2991254 ,
          0.1451596 ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ]]]))

From the output you can see that for the length-10 sequence, sequence_length makes no difference.
For the second sequence, of length 6, the forward outputs at the padded time steps are zero when sequence_length is given; without it, the forward pass keeps computing through the padding.
Looking at the function source,

    output_bw = _reverse(
        tmp, seq_lengths=sequence_length,
        seq_axis=time_axis, batch_axis=batch_axis)

we can see that the backward outputs come out the same either way, which matches the experiment: without sequence_length, the fully reversed input starts with the padded zeros, and (as noted above) a freshly initialized LSTM fed zeros from a zero state stays exactly at zero, so the backward state is untouched until the real data arrives; the final reverse then puts those zeros back onto the padded positions.
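A quick numerical check of this (my own addition, run inside the same with-block as the bidirectional_dynamic_rnn code above; outputs and outputs1 are the (output_fw, output_bw) tuples with and without sequence_length):

bw_with, bw_without = sess.run([outputs[1], outputs1[1]])
print(np.allclose(bw_with, bw_without))  # expected: True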

Finally, we arrive at the stack_bidirectional_dynamic_rnn function:

def lstm_cell(lstm_unit=5):
    cell = tf.nn.rnn_cell.LSTMCell(num_units=lstm_unit)
    return cell

with tf.Session() as sess:
    # a two-layer bidirectional LSTM: the two layers are stacked inside a
    # MultiRNNCell, which is then passed as a single-element cell list
    cells_fw = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(2)])
    cells_bw = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(2)])
    outputs, output_state_fw, output_state_bw = \
        tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
            cells_fw=[cells_fw], cells_bw=[cells_bw], inputs=X,
            sequence_length=X_lengths, dtype=tf.float64)
    outputs1, output_state_fw1, output_state_bw1 = \
        tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
            cells_fw=[cells_fw], cells_bw=[cells_bw], inputs=X, dtype=tf.float64)
    sess.run(tf.global_variables_initializer())
    print(sess.run(outputs))
    print(sess.run(outputs1))

Output:
[[[-1.28601883e-02 -1.22699760e-04  2.87583047e-02 -1.12176756e-02
   -8.02404394e-03 -5.23670159e-02  1.56707537e-01  1.06681455e-01
    2.62892937e-02 -1.14040429e-01]
  [-3.43349540e-02 -2.66737734e-03  6.48235389e-02 -1.54243659e-02
   -7.99647407e-03 -7.89364081e-02  1.52510327e-01  1.01263909e-01
    2.12051295e-02 -9.87228364e-02]
  [-7.06332778e-02 -1.31020952e-02  1.07337490e-01 -3.73839672e-02
   -1.85424702e-02 -9.77438210e-02  1.35022125e-01  5.90582636e-02
    1.63908052e-03 -7.43389943e-02]
  [-8.23076663e-02 -3.82813864e-02  1.37970988e-01 -2.99934549e-02
   -2.68845690e-02 -1.02355369e-01  9.12888953e-02 -3.53093145e-04
   -3.53520397e-02 -5.78286482e-02]
  [-9.50498624e-02 -4.60312526e-02  1.67575362e-01 -2.74771200e-02
   -8.43013199e-03 -8.06010196e-02  5.54779549e-02 -4.75758513e-02
   -6.47748359e-02 -2.75317154e-02]
  [-7.16351738e-02 -4.41277788e-02  1.45701763e-01  1.24855432e-03
    1.46169942e-02 -7.60465568e-02  5.93749022e-02 -6.17292026e-02
   -5.86224445e-02 -2.66687336e-02]
  [-6.16778231e-02 -4.72333447e-02  1.11211251e-01  3.88144082e-02
    4.30879841e-02 -3.65799040e-02  5.09810184e-02 -3.84721677e-02
   -2.77666367e-02 -1.34022960e-02]
  [-4.86744743e-02 -2.13668975e-02  6.73850765e-02  2.83607342e-02
    3.75809411e-02 -2.51348881e-02  2.83121865e-02 -2.06509712e-02
   -2.13905453e-02 -4.74981693e-03]
  [-3.49021170e-02 -1.49223217e-02  5.07296592e-02  3.44248124e-02
    1.76991886e-02 -8.80179158e-03 -8.46805269e-03 -2.06985742e-02
   -3.19633058e-02  9.95127577e-03]
  [-1.98075939e-02 -5.06471247e-03  2.12698655e-02  3.85639213e-02
    2.10663157e-02 -1.11487135e-02 -7.10696841e-03 -6.24988002e-03
   -1.48659066e-02  6.16684391e-03]]
 [[ 3.62599737e-03  2.94539921e-03 -2.58042726e-02  6.62572147e-03
    8.17662302e-03  3.21574858e-02  8.00241626e-02  2.88561398e-03
    2.25625696e-02 -2.97773596e-02]
  [ 1.24748018e-02  3.25943748e-02 -6.98699079e-02 -1.99415204e-02
    8.98645467e-03  2.75950596e-02  6.77296140e-02 -1.72919213e-03
    2.37468487e-02 -3.50077631e-02]
  [ 2.12325089e-02  3.92948087e-02 -9.18501935e-02 -2.87299468e-02
   -3.20147555e-04  1.46104383e-02  7.63723852e-02  1.26062855e-02
    2.64755854e-02 -2.93706603e-02]
  [ 2.27776689e-02  4.12789392e-02 -8.49281608e-02 -3.98009553e-02
   -2.13557998e-02 -1.36286022e-03  4.99668684e-02  2.34237759e-02
    2.02189041e-02 -2.05926547e-02]
  [ 1.99018813e-02  3.41225257e-02 -8.66679609e-02 -3.45312179e-02
   -1.78512544e-02 -4.82369387e-03  2.14085347e-02  3.65739249e-02
    9.21655481e-03 -1.30496637e-03]
  [ 2.62904309e-02  6.18586705e-02 -9.56436806e-02 -6.58077095e-02
   -2.72908745e-02  4.50642484e-03  1.30500698e-02  1.51553697e-02
    1.34644969e-02 -5.49530647e-03]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00]]]
[[[-1.28601883e-02 -1.22699760e-04  2.87583047e-02 -1.12176756e-02
   -8.02404394e-03 -5.23670159e-02  1.56707537e-01  1.06681455e-01
    2.62892937e-02 -1.14040429e-01]
  [-3.43349540e-02 -2.66737734e-03  6.48235389e-02 -1.54243659e-02
   -7.99647407e-03 -7.89364081e-02  1.52510327e-01  1.01263909e-01
    2.12051295e-02 -9.87228364e-02]
  [-7.06332778e-02 -1.31020952e-02  1.07337490e-01 -3.73839672e-02
   -1.85424702e-02 -9.77438210e-02  1.35022125e-01  5.90582636e-02
    1.63908052e-03 -7.43389943e-02]
  [-8.23076663e-02 -3.82813864e-02  1.37970988e-01 -2.99934549e-02
   -2.68845690e-02 -1.02355369e-01  9.12888953e-02 -3.53093145e-04
   -3.53520397e-02 -5.78286482e-02]
  [-9.50498624e-02 -4.60312526e-02  1.67575362e-01 -2.74771200e-02
   -8.43013199e-03 -8.06010196e-02  5.54779549e-02 -4.75758513e-02
   -6.47748359e-02 -2.75317154e-02]
  [-7.16351738e-02 -4.41277788e-02  1.45701763e-01  1.24855432e-03
    1.46169942e-02 -7.60465568e-02  5.93749022e-02 -6.17292026e-02
   -5.86224445e-02 -2.66687336e-02]
  [-6.16778231e-02 -4.72333447e-02  1.11211251e-01  3.88144082e-02
    4.30879841e-02 -3.65799040e-02  5.09810184e-02 -3.84721677e-02
   -2.77666367e-02 -1.34022960e-02]
  [-4.86744743e-02 -2.13668975e-02  6.73850765e-02  2.83607342e-02
    3.75809411e-02 -2.51348881e-02  2.83121865e-02 -2.06509712e-02
   -2.13905453e-02 -4.74981693e-03]
  [-3.49021170e-02 -1.49223217e-02  5.07296592e-02  3.44248124e-02
    1.76991886e-02 -8.80179158e-03 -8.46805269e-03 -2.06985742e-02
   -3.19633058e-02  9.95127577e-03]
  [-1.98075939e-02 -5.06471247e-03  2.12698655e-02  3.85639213e-02
    2.10663157e-02 -1.11487135e-02 -7.10696841e-03 -6.24988002e-03
   -1.48659066e-02  6.16684391e-03]]
 [[ 3.62599737e-03  2.94539921e-03 -2.58042726e-02  6.62572147e-03
    8.17662302e-03  3.21574858e-02  8.00241626e-02  2.88561398e-03
    2.25625696e-02 -2.97773596e-02]
  [ 1.24748018e-02  3.25943748e-02 -6.98699079e-02 -1.99415204e-02
    8.98645467e-03  2.75950596e-02  6.77296140e-02 -1.72919213e-03
    2.37468487e-02 -3.50077631e-02]
  [ 2.12325089e-02  3.92948087e-02 -9.18501935e-02 -2.87299468e-02
   -3.20147555e-04  1.46104383e-02  7.63723852e-02  1.26062855e-02
    2.64755854e-02 -2.93706603e-02]
  [ 2.27776689e-02  4.12789392e-02 -8.49281608e-02 -3.98009553e-02
   -2.13557998e-02 -1.36286022e-03  4.99668684e-02  2.34237759e-02
    2.02189041e-02 -2.05926547e-02]
  [ 1.99018813e-02  3.41225257e-02 -8.66679609e-02 -3.45312179e-02
   -1.78512544e-02 -4.82369387e-03  2.14085347e-02  3.65739249e-02
    9.21655481e-03 -1.30496637e-03]
  [ 2.62904309e-02  6.18586705e-02 -9.56436806e-02 -6.58077095e-02
   -2.72908745e-02  4.50642484e-03  1.30500698e-02  1.51553697e-02
    1.34644969e-02 -5.49530647e-03]
  [ 3.05670502e-02  6.18263175e-02 -7.87989639e-02 -7.46175202e-02
   -3.49183715e-02  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00]
  [ 3.33352693e-02  5.59230982e-02 -6.25883387e-02 -7.42776119e-02
   -3.54990111e-02  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00]
  [ 3.38701860e-02  4.70675661e-02 -4.76290537e-02 -6.78331036e-02
   -3.13554307e-02  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00]
  [ 3.23695247e-02  3.73589351e-02 -3.50100669e-02 -5.80015682e-02
   -2.49360807e-02  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00]]]

You can see that with sequence_length, the padded portion of stack_bidirectional_dynamic_rnn's output is all zeros, whereas without it the padded positions still carry values in the forward half of the concatenated output (the backward half is zero either way, for the reason given above).

With these experiments, it is finally clear how sequence_length affects the results of stack_bidirectional_dynamic_rnn, bidirectional_dynamic_rnn, and dynamic_rnn. Although the docstrings all mark it as optional, I would recommend passing the actual lengths, to avoid potential errors and unnecessary intermediate computation. Whether it changes the final result depends on how the downstream code processes these outputs.
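For training, one common downstream pattern (a hedged sketch of my own, not from the post; logits, targets, and lengths are hypothetical tensors) is to mask the per-step loss with tf.sequence_mask so the padded steps contribute nothing:

# logits:  [batch, max_time, num_classes]
# targets: [batch, max_time] integer labels
# lengths: [batch] actual sequence lengths
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=targets, logits=logits)                       # [batch, max_time]
mask = tf.sequence_mask(lengths, maxlen=tf.shape(targets)[1],
                        dtype=losses.dtype)              # [batch, max_time]
loss = tf.reduce_sum(losses * mask) / tf.reduce_sum(mask)  # mean over real steps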

When I have time, I'll take a look at how sequence_length affects the prediction and the loss.

As a rather green programmer, I've learned a lot of practical knowledge and techniques on CSDN. This is my first blog post; going forward I'll try to write up the problems I run into while coding, to improve my skills.
