Understanding the sequence_length parameter of TensorFlow's stack_bidirectional_dynamic_rnn and bidirectional_dynamic_rnn
While working on my graduation project, my code used the stack_bidirectional_dynamic_rnn API. I spent two or three days puzzling over its sequence_length parameter, unsure whether it really needed to be passed in. Yesterday I read the relevant source code and now have a reasonable picture.
Reading the source shows that the function internally calls bidirectional_dynamic_rnn in a for loop. Here is the relevant source (link):
def stack_bidirectional_dynamic_rnn(cells_fw,
                                    cells_bw,
                                    inputs,
                                    initial_states_fw=None,
                                    initial_states_bw=None,
                                    dtype=None,
                                    sequence_length=None,
                                    parallel_iterations=None,
                                    time_major=False,
                                    scope=None):
  """
  ...
  Args:
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences.
  ...
  """
  ...
  states_fw = []
  states_bw = []
  prev_layer = inputs
  with vs.variable_scope(scope or "stack_bidirectional_rnn"):
    for i, (cell_fw, cell_bw) in enumerate(zip(cells_fw, cells_bw)):
      initial_state_fw = None
      initial_state_bw = None
      if initial_states_fw:
        initial_state_fw = initial_states_fw[i]
      if initial_states_bw:
        initial_state_bw = initial_states_bw[i]

      with vs.variable_scope("cell_%d" % i):
        outputs, (state_fw, state_bw) = rnn.bidirectional_dynamic_rnn(
            cell_fw,
            cell_bw,
            prev_layer,
            initial_state_fw=initial_state_fw,
            initial_state_bw=initial_state_bw,
            sequence_length=sequence_length,
            parallel_iterations=parallel_iterations,
            dtype=dtype,
            time_major=time_major)
        # Concat the outputs to create the new input.
        prev_layer = array_ops.concat(outputs, 2)
      states_fw.append(state_fw)
      states_bw.append(state_bw)
  return prev_layer, tuple(states_fw), tuple(states_bw)
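The stacking logic above can be illustrated with a toy numpy sketch. This is only a stand-in under stated assumptions: the per-layer "RNN" below is just a running mean, not a real LSTM, and `run_layer_fw`, `run_layer_bw`, and `stack_bidirectional` are hypothetical helper names. The point is only how each layer's forward/backward outputs are concatenated on the feature axis (like `array_ops.concat(outputs, 2)`) and fed to the next layer as `prev_layer`.

```python
import numpy as np

def run_layer_fw(x):
    # stand-in for a forward RNN over [batch, time, feat]: a running mean over time
    return np.cumsum(x, axis=1) / np.arange(1, x.shape[1] + 1)[None, :, None]

def run_layer_bw(x):
    # stand-in for a backward RNN: flip time, run "forward", flip back
    return run_layer_fw(x[:, ::-1, :])[:, ::-1, :]

def stack_bidirectional(x, num_layers):
    # mimic the loop above: concatenate fw/bw outputs on the feature axis
    # and feed the result to the next layer as prev_layer
    prev_layer = x
    for _ in range(num_layers):
        out_fw = run_layer_fw(prev_layer)
        out_bw = run_layer_bw(prev_layer)
        prev_layer = np.concatenate([out_fw, out_bw], axis=2)
    return prev_layer

x = np.random.randn(3, 4, 2)
print(stack_bidirectional(x, num_layers=2).shape)  # (3, 4, 8)
```

Each layer doubles the feature dimension of `prev_layer`, which is why the final output's last axis is 2 × num_layers-fold wider than the input here.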
Note that the docstring marks sequence_length as optional. However, while debugging I found that if sequence_length is omitted, then for a given example of shape [max_time, layers_output], the timesteps past the actual sequence length still hold nonzero, varying outputs: the computation keeps propagating through the padding. During model training this may affect the final prediction.
Now let's look at how bidirectional_dynamic_rnn (link) handles the sequence_length parameter; note that internally it calls dynamic_rnn.
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None,
                              swap_memory=False, time_major=False, scope=None):
  ...
  Args:
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences in the batch.
      If not provided, all batch entries are assumed to be full sequences; and
      time reversal is applied from time `0` to `max_time` for each sequence.
  ...
  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
      output_fw, output_state_fw = dynamic_rnn(
          cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
          initial_state=initial_state_fw, dtype=dtype,
          parallel_iterations=parallel_iterations, swap_memory=swap_memory,
          time_major=time_major, scope=fw_scope)
In the forward direction, sequence_length is passed straight into dynamic_rnn.
Now comes the key part: in the backward direction, the input must first be reversed before dynamic_rnn is applied.
    # Backward direction
    if not time_major:
      time_axis = 1
      batch_axis = 0
    else:
      time_axis = 0
      batch_axis = 1

    def _reverse(input_, seq_lengths, seq_axis, batch_axis):
      if seq_lengths is not None:
        return array_ops.reverse_sequence(
            input=input_, seq_lengths=seq_lengths,
            seq_axis=seq_axis, batch_axis=batch_axis)
      else:
        return array_ops.reverse(input_, axis=[seq_axis])

    with vs.variable_scope("bw") as bw_scope:

      def _map_reverse(inp):
        return _reverse(
            inp,
            seq_lengths=sequence_length,
            seq_axis=time_axis,
            batch_axis=batch_axis)

      inputs_reverse = nest.map_structure(_map_reverse, inputs)
      tmp, output_state_bw = dynamic_rnn(
          cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
          initial_state=initial_state_bw, dtype=dtype,
          parallel_iterations=parallel_iterations, swap_memory=swap_memory,
          time_major=time_major, scope=bw_scope)
In the backward direction, the input is reversed via _map_reverse, i.e. the _reverse function, and this is exactly where sequence_length matters. If the parameter is not supplied, array_ops.reverse flips the entire sequence, so for variable-length inputs the zero padding is fed in first at the start of the backward pass. If sequence_length is supplied, the input is instead flipped with array_ops.reverse_sequence.
First, the difference between the two reversal functions:
import tensorflow as tf
import numpy as np

with tf.Session() as sess:
    # a has shape [3, 4, 2]
    a = np.array([[[1, 2], [2, 1], [4, 3], [0, 0]],
                  [[2, 1], [3, 4], [0, 0], [0, 0]],
                  [[3, 5], [1, 3], [4, 6], [5, 2]]])
    seq_length = [3, 2, 4]
    b = tf.reverse_sequence(a, seq_length, 1, 0)
    c = tf.reverse(a, axis=[1])
    print(sess.run(b))
    print(sess.run(c))
Output:
[[[4 3]
[2 1]
[1 2]
[0 0]]
[[3 4]
[2 1]
[0 0]
[0 0]]
[[5 2]
[4 6]
[1 3]
[3 5]]]
[[[0 0]
[4 3]
[2 1]
[1 2]]
[[0 0]
[0 0]
[3 4]
[2 1]]
[[5 2]
[4 6]
[1 3]
[3 5]]]
As the output shows, reverse_sequence flips only the actual length of each sequence, leaving the padding zeros at the end, while reverse flips everything.
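The semantics of reverse_sequence can also be mimicked in plain numpy. This is a sketch for the batch-major case only (seq_axis=1, batch_axis=0), and `reverse_sequence_np` is a hypothetical helper name, not a TF API:

```python
import numpy as np

def reverse_sequence_np(x, seq_lengths):
    # numpy equivalent of tf.reverse_sequence for the batch-major case:
    # reverse only the first seq_lengths[i] timesteps of each example;
    # the padded tail stays in place
    out = x.copy()
    for i, n in enumerate(seq_lengths):
        out[i, :n] = x[i, :n][::-1]
    return out

a = np.array([[[1, 2], [2, 1], [4, 3], [0, 0]],
              [[2, 1], [3, 4], [0, 0], [0, 0]],
              [[3, 5], [1, 3], [4, 6], [5, 2]]])
print(reverse_sequence_np(a, [3, 2, 4])[0])  # [[4 3] [2 1] [1 2] [0 0]]
print(a[:, ::-1][0])                         # full reverse: [[0 0] [4 3] [2 1] [1 2]]
```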
So, the first conclusion: stack_bidirectional_dynamic_rnn's and bidirectional_dynamic_rnn's handling of sequence_length comes down to choosing between two different reversal functions, reverse_sequence and reverse.
Then: how does passing (or omitting) sequence_length affect the output of stack_bidirectional_dynamic_rnn and bidirectional_dynamic_rnn?
In my view, there are two potential effects on the output:
- In the forward pass, an example shorter than max_time keeps being computed past its actual length over the zero-padded timesteps, generating unnecessary, meaningless values and lengthening training;
- In the backward pass of the bidirectional RNN, after a full reverse the zero-padded timesteps come first, so every step computed before the real data begins is wasted.
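Both effects can be seen in a toy numpy model of dynamic_rnn's masking rule. This is a stand-in recurrence h_t = tanh(h_{t-1} + x_t), not an LSTM, and `toy_rnn` is a hypothetical helper: past an example's true length the output is zeroed and the state stops being updated, while without lengths the padded steps keep being computed.

```python
import numpy as np

def toy_rnn(x, seq_lengths=None):
    # toy recurrence h_t = tanh(h_{t-1} + x_t) illustrating dynamic_rnn's
    # sequence_length handling: past an example's true length the output
    # is zeroed and the state is copied through unchanged
    batch, max_time, feat = x.shape
    h = np.zeros((batch, feat))
    outputs = np.zeros((batch, max_time, feat))
    for t in range(max_time):
        new_h = np.tanh(h + x[:, t])
        if seq_lengths is None:
            h = new_h
            outputs[:, t] = new_h
        else:
            active = (np.asarray(seq_lengths) > t)[:, None]
            h = np.where(active, new_h, h)                # hold state past the length
            outputs[:, t] = np.where(active, new_h, 0.0)  # zero the padded outputs
    return outputs

x = np.random.randn(2, 5, 3)
x[1, 3:] = 0.0                             # second example: 3 real steps + padding
out_masked = toy_rnn(x, [5, 3])
out_plain = toy_rnn(x)
print(out_masked[1, 3:])                   # zeros: padded steps are skipped
print(np.abs(out_plain[1, 3:]).max() > 0)  # True: padding keeps being computed
```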
It remains to check whether passing sequence_length affects the accuracy of the computed results.
Starting from the basics, first look at how sequence_length affects the output of dynamic_rnn:
with tf.Session() as sess:
    X = np.random.randn(2, 10, 8)
    # the second example has actual length 6
    X[1, 6:] = 0.0
    # X[1, :6] = 0.0
    X_lengths = [10, 6]
    cell = tf.nn.rnn_cell.LSTMCell(num_units=5)
    outputs, last_states = tf.nn.dynamic_rnn(
        cell=cell,
        dtype=tf.float64,
        sequence_length=X_lengths,
        inputs=X)
    outputs1, last_states1 = tf.nn.dynamic_rnn(
        cell=cell,
        dtype=tf.float64,
        inputs=X)
    sess.run(tf.global_variables_initializer())
    print(sess.run(outputs))
    print(sess.run(outputs1))
Output:
[[[ 0.3664112 -0.16582295 0.03347532 -0.1209446 -0.16412497]
[ 0.35257143 -0.05686 0.03231346 -0.13748774 -0.19436866]
[ 0.53990554 -0.0629166 -0.01856082 -0.08934467 -0.19854734]
[ 0.62018885 -0.14675617 -0.23786601 -0.18285464 -0.30012293]
[ 0.29875151 -0.07205287 -0.09079972 -0.32260035 -0.07121483]
[ 0.12284296 0.05494323 -0.26602965 -0.31480686 -0.00965394]
[-0.04321202 0.10846491 -0.15792417 0.04715865 0.16775026]
[ 0.14991799 0.24498104 -0.22123176 -0.16537638 0.17049721]
[ 0.22538293 0.21274618 -0.30151976 -0.2806619 0.07987886]
[ 0.16169614 -0.09109676 -0.23507253 -0.2256854 0.08532138]]
[[ 0.03130918 -0.00153811 0.05989608 0.00258946 0.12982402]
[-0.02059972 0.24180047 -0.01174238 0.20337912 0.11303384]
[ 0.04525809 0.25172152 -0.12350006 0.09227578 0.25866815]
[ 0.13557124 0.3672627 -0.3223366 -0.05648763 0.20540585]
[ 0.18130031 0.20992908 -0.09175851 0.07085576 0.23343628]
[ 0.34051405 0.28743254 -0.10791713 -0.09506878 0.12702959]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]]]
[[[ 0.3664112 -0.16582295 0.03347532 -0.1209446 -0.16412497]
[ 0.35257143 -0.05686 0.03231346 -0.13748774 -0.19436866]
[ 0.53990554 -0.0629166 -0.01856082 -0.08934467 -0.19854734]
[ 0.62018885 -0.14675617 -0.23786601 -0.18285464 -0.30012293]
[ 0.29875151 -0.07205287 -0.09079972 -0.32260035 -0.07121483]
[ 0.12284296 0.05494323 -0.26602965 -0.31480686 -0.00965394]
[-0.04321202 0.10846491 -0.15792417 0.04715865 0.16775026]
[ 0.14991799 0.24498104 -0.22123176 -0.16537638 0.17049721]
[ 0.22538293 0.21274618 -0.30151976 -0.2806619 0.07987886]
[ 0.16169614 -0.09109676 -0.23507253 -0.2256854 0.08532138]]
[[ 0.03130918 -0.00153811 0.05989608 0.00258946 0.12982402]
[-0.02059972 0.24180047 -0.01174238 0.20337912 0.11303384]
[ 0.04525809 0.25172152 -0.12350006 0.09227578 0.25866815]
[ 0.13557124 0.3672627 -0.3223366 -0.05648763 0.20540585]
[ 0.18130031 0.20992908 -0.09175851 0.07085576 0.23343628]
[ 0.34051405 0.28743254 -0.10791713 -0.09506878 0.12702959]
[ 0.20726339 0.2773752 -0.18121164 -0.06436528 0.1766949 ]
[ 0.19274723 0.24523639 -0.14879227 -0.05048222 0.20608781]
[ 0.17663077 0.22035726 -0.12688077 -0.03431938 0.22335574]
[ 0.16139967 0.19805899 -0.11295674 -0.01820427 0.23194738]]]
Comparing the two, with sequence_length the timesteps past an example's actual length are simply not computed and come out as zeros.
If instead the whole sequence is flipped so that the first 6 timesteps are zero, dynamic_rnn's outputs for those timesteps are all zero as well:
    X[1, :6] = 0.0
    outputs1, last_states1 = tf.nn.dynamic_rnn(
        cell=cell,
        dtype=tf.float64,
        inputs=X)
    sess.run(tf.global_variables_initializer())
    print(sess.run(outputs1))
Output:
[[[ 0.04069017 -0.23771831 0.045805 0.09560217 0.05502168]
[ 0.09332863 -0.04726055 0.11690663 0.08879598 0.25465747]
[ 0.19825957 -0.28982961 0.08313738 0.20992825 0.02033359]
[ 0.15471266 -0.56485913 0.05661117 0.20066938 -0.03992696]
[ 0.45740853 -0.0688219 0.34996705 0.07094942 -0.12920683]
[ 0.45969932 0.11180939 -0.07534739 -0.03005722 -0.02074522]
[ 0.25090551 0.03269829 -0.25112071 -0.28449897 -0.03996906]
[ 0.00671926 0.04019763 -0.01132277 -0.06963076 0.00314332]
[ 0.02276943 0.24375711 -0.18096459 0.10516443 -0.0245148 ]
[-0.07252719 0.21662108 -0.06866634 0.09010588 0.04821544]]
[[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]
[ 0.15381835 0.05551223 0.10847534 0.11666546 -0.10370024]
[ 0.15700091 -0.1763513 0.16971487 0.0940137 -0.08755241]
[-0.00314326 -0.16764532 0.13892345 0.05829588 -0.2285242 ]
[ 0.16221322 -0.24974877 -0.17015645 -0.08633168 -0.05848374]]]
Now look at the output of bidirectional_dynamic_rnn:
with tf.Session() as sess:
    cell = tf.nn.rnn_cell.LSTMCell(num_units=5)
    outputs, last_states = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=cell,
        cell_bw=cell,
        dtype=tf.float64,
        sequence_length=X_lengths,
        inputs=X)
    outputs1, last_states1 = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=cell,
        cell_bw=cell,
        dtype=tf.float64,
        inputs=X)
    sess.run(tf.global_variables_initializer())
    print(sess.run(outputs))
    print(sess.run(outputs1))
Output:
(array([[[ 1.03637534e-01, -7.10830000e-02, -6.15507404e-02,
1.82409063e-01, -1.27943658e-02],
[-1.71373530e-01, -1.63462122e-01, 5.82945984e-02,
2.30467173e-01, -5.47676139e-02],
[ 5.13349422e-02, -1.82497689e-01, -1.75145553e-01,
1.53907335e-01, -7.97362981e-02],
[ 1.06949741e-01, -4.01306522e-02, -2.10730327e-01,
1.09050034e-01, -2.32756766e-01],
[ 4.60590349e-02, -4.58901677e-02, -1.58516232e-01,
2.01659355e-02, 2.69986613e-02],
[ 1.34656088e-01, 1.27036696e-01, -1.17810220e-02,
-3.33886708e-02, 4.41047476e-01],
[ 2.22404726e-01, 1.59448161e-01, -1.01579624e-01,
-5.48369031e-02, 1.06230121e-01],
[ 2.30960285e-01, -3.90051280e-02, -2.24918456e-02,
-2.35480280e-02, 1.52811464e-02],
[ 2.10968050e-01, 1.83505109e-01, 1.78626573e-01,
-2.55444308e-01, 4.00898776e-01],
[ 2.58521311e-01, 3.29425438e-01, 1.17003287e-01,
-2.77698999e-02, 1.49407274e-01]],
[[-1.58240424e-01, -8.11458422e-02, 4.00896978e-02,
1.04382802e-01, -3.94471746e-03],
[-3.35112999e-01, -9.55060361e-02, 2.90027134e-01,
-1.81891039e-01, 9.10953666e-03],
[-1.67419074e-01, -4.00547833e-04, 1.95651535e-01,
-2.47421707e-01, 5.14707873e-02],
[-3.55542561e-02, 1.83676786e-01, 1.79457375e-01,
-4.30804504e-01, 8.66083426e-02],
[-2.09451275e-01, 2.34769980e-01, 3.52470628e-01,
-1.28957238e-01, 2.21493636e-01],
[-3.41874621e-01, -2.51958521e-02, 5.17341140e-01,
-3.61904359e-01, 2.32170058e-01],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00]]]), array([[[ 0.16441463, -0.19302047, -0.09601584, 0.26801333,
-0.03560607],
[ 0.17595892, -0.15906587, -0.04336751, 0.20887378,
-0.03159793],
[ 0.2235174 , -0.05801825, -0.17605939, 0.091533 ,
0.0131942 ],
[ 0.11473129, 0.09345286, 0.01537781, -0.01188187,
0.14568516],
[ 0.03106931, 0.08806389, 0.40831446, -0.10394683,
0.36716515],
[ 0.2329871 , 0.15022 , 0.11296073, -0.07147051,
0.4857348 ],
[ 0.20390796, 0.01588418, 0.03380659, -0.03090537,
0.09930256],
[ 0.24167838, -0.00490647, 0.16429447, -0.00820602,
0.08831998],
[ 0.23332149, 0.18020251, 0.12465346, -0.1750562 ,
0.40673199],
[ 0.15236341, -0.0623327 , -0.07389674, 0.0791574 ,
0.02982562]],
[[-0.36612451, -0.09818593, 0.37779096, -0.10673546,
0.09468116],
[-0.24289844, -0.00789502, 0.43728569, -0.36032157,
0.18526209],
[-0.05659099, 0.23216722, 0.28674557, -0.31059676,
0.15522032],
[-0.22365649, 0.18580737, 0.2019476 , -0.41131144,
0.19063189],
[-0.38062693, 0.02047546, 0.36591703, -0.11403186,
0.28162637],
[-0.22493386, -0.06232192, 0.41815278, -0.2991254 ,
0.1451596 ],
[ 0. , 0. , 0. , 0. ,
0. ],
[ 0. , 0. , 0. , 0. ,
0. ],
[ 0. , 0. , 0. , 0. ,
0. ],
[ 0. , 0. , 0. , 0. ,
0. ]]]))
(array([[[ 1.03637534e-01, -7.10830000e-02, -6.15507404e-02,
1.82409063e-01, -1.27943658e-02],
[-1.71373530e-01, -1.63462122e-01, 5.82945984e-02,
2.30467173e-01, -5.47676139e-02],
[ 5.13349422e-02, -1.82497689e-01, -1.75145553e-01,
1.53907335e-01, -7.97362981e-02],
[ 1.06949741e-01, -4.01306522e-02, -2.10730327e-01,
1.09050034e-01, -2.32756766e-01],
[ 4.60590349e-02, -4.58901677e-02, -1.58516232e-01,
2.01659355e-02, 2.69986613e-02],
[ 1.34656088e-01, 1.27036696e-01, -1.17810220e-02,
-3.33886708e-02, 4.41047476e-01],
[ 2.22404726e-01, 1.59448161e-01, -1.01579624e-01,
-5.48369031e-02, 1.06230121e-01],
[ 2.30960285e-01, -3.90051280e-02, -2.24918456e-02,
-2.35480280e-02, 1.52811464e-02],
[ 2.10968050e-01, 1.83505109e-01, 1.78626573e-01,
-2.55444308e-01, 4.00898776e-01],
[ 2.58521311e-01, 3.29425438e-01, 1.17003287e-01,
-2.77698999e-02, 1.49407274e-01]],
[[-1.58240424e-01, -8.11458422e-02, 4.00896978e-02,
1.04382802e-01, -3.94471746e-03],
[-3.35112999e-01, -9.55060361e-02, 2.90027134e-01,
-1.81891039e-01, 9.10953666e-03],
[-1.67419074e-01, -4.00547833e-04, 1.95651535e-01,
-2.47421707e-01, 5.14707873e-02],
[-3.55542561e-02, 1.83676786e-01, 1.79457375e-01,
-4.30804504e-01, 8.66083426e-02],
[-2.09451275e-01, 2.34769980e-01, 3.52470628e-01,
-1.28957238e-01, 2.21493636e-01],
[-3.41874621e-01, -2.51958521e-02, 5.17341140e-01,
-3.61904359e-01, 2.32170058e-01],
[-2.51546442e-01, -1.18885843e-02, 2.96923579e-01,
-1.86115662e-01, 2.16003435e-01],
[-2.14802264e-01, 7.02766447e-04, 2.61682611e-01,
-9.91503917e-02, 1.88783995e-01],
[-1.87142315e-01, 1.90646907e-03, 2.19077987e-01,
-3.69484125e-02, 1.55271455e-01],
[-1.63632591e-01, -6.82563058e-04, 1.78274888e-01,
2.33492655e-03, 1.24669109e-01]]]), array([[[ 0.16441463, -0.19302047, -0.09601584, 0.26801333,
-0.03560607],
[ 0.17595892, -0.15906587, -0.04336751, 0.20887378,
-0.03159793],
[ 0.2235174 , -0.05801825, -0.17605939, 0.091533 ,
0.0131942 ],
[ 0.11473129, 0.09345286, 0.01537781, -0.01188187,
0.14568516],
[ 0.03106931, 0.08806389, 0.40831446, -0.10394683,
0.36716515],
[ 0.2329871 , 0.15022 , 0.11296073, -0.07147051,
0.4857348 ],
[ 0.20390796, 0.01588418, 0.03380659, -0.03090537,
0.09930256],
[ 0.24167838, -0.00490647, 0.16429447, -0.00820602,
0.08831998],
[ 0.23332149, 0.18020251, 0.12465346, -0.1750562 ,
0.40673199],
[ 0.15236341, -0.0623327 , -0.07389674, 0.0791574 ,
0.02982562]],
[[-0.36612451, -0.09818593, 0.37779096, -0.10673546,
0.09468116],
[-0.24289844, -0.00789502, 0.43728569, -0.36032157,
0.18526209],
[-0.05659099, 0.23216722, 0.28674557, -0.31059676,
0.15522032],
[-0.22365649, 0.18580737, 0.2019476 , -0.41131144,
0.19063189],
[-0.38062693, 0.02047546, 0.36591703, -0.11403186,
0.28162637],
[-0.22493386, -0.06232192, 0.41815278, -0.2991254 ,
0.1451596 ],
[ 0. , 0. , 0. , 0. ,
0. ],
[ 0. , 0. , 0. , 0. ,
0. ],
[ 0. , 0. , 0. , 0. ,
0. ],
[ 0. , 0. , 0. , 0. ,
0. ]]]))
From the output, sequence_length makes no difference for the full-length (10-step) sequence.
For the second sequence with actual length 6, passing sequence_length makes the forward outputs at the padded timesteps zero, whereas without it computation continues through the padding.
From the source,
output_bw = _reverse(
    tmp, seq_lengths=sequence_length,
    seq_axis=time_axis, batch_axis=batch_axis)
the backward outputs are un-reversed in the same way, so the backward direction shows the same behavior, consistent with the experimental results.
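The whole backward pipeline (reverse_sequence on the input, run the RNN forward, reverse_sequence on the outputs) can be sketched with a toy numpy recurrence. `reverse_seq` and `toy_rnn` are hypothetical helpers, not TF APIs, and the recurrence is h_t = tanh(h_{t-1} + x_t) rather than an LSTM. Note how the padding stays at the tail through both flips, so the final backward output is aligned with the original time axis and zero at the padded steps:

```python
import numpy as np

def reverse_seq(x, lengths):
    # reverse only the first lengths[i] steps of each example (cf. reverse_sequence)
    out = x.copy()
    for i, n in enumerate(lengths):
        out[i, :n] = x[i, :n][::-1]
    return out

def toy_rnn(x, lengths):
    # toy masked recurrence h_t = tanh(h_{t-1} + x_t) standing in for dynamic_rnn
    h = np.zeros((x.shape[0], x.shape[2]))
    outs = np.zeros_like(x)
    for t in range(x.shape[1]):
        active = (np.asarray(lengths) > t)[:, None]
        new_h = np.tanh(h + x[:, t])
        h = np.where(active, new_h, h)
        outs[:, t] = np.where(active, new_h, 0.0)
    return outs

x = np.random.randn(2, 5, 3)
x[1, 3:] = 0.0
lengths = [5, 3]
tmp = toy_rnn(reverse_seq(x, lengths), lengths)  # run over the reversed input
output_bw = reverse_seq(tmp, lengths)            # un-reverse: step t aligns with input step t
print((output_bw[1, 3:] == 0).all())             # True: padding stays at the tail
```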
Finally, stack_bidirectional_dynamic_rnn:
def lstm_cell(lstm_unit=5):
    cell = tf.nn.rnn_cell.LSTMCell(num_units=lstm_unit)
    return cell

with tf.Session() as sess:
    # a two-layer bidirectional LSTM
    cells_fw = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(2)])
    cells_bw = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(2)])
    outputs, output_state_fw, output_state_bw = \
        tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
            cells_fw=[cells_fw], cells_bw=[cells_bw], inputs=X,
            sequence_length=X_lengths, dtype=tf.float64)
    outputs1, output_state_fw1, output_state_bw1 = \
        tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
            cells_fw=[cells_fw], cells_bw=[cells_bw], inputs=X, dtype=tf.float64)
    sess.run(tf.global_variables_initializer())
    print(sess.run(outputs))
    print(sess.run(outputs1))
Output:
[[[-1.28601883e-02 -1.22699760e-04 2.87583047e-02 -1.12176756e-02
-8.02404394e-03 -5.23670159e-02 1.56707537e-01 1.06681455e-01
2.62892937e-02 -1.14040429e-01]
[-3.43349540e-02 -2.66737734e-03 6.48235389e-02 -1.54243659e-02
-7.99647407e-03 -7.89364081e-02 1.52510327e-01 1.01263909e-01
2.12051295e-02 -9.87228364e-02]
[-7.06332778e-02 -1.31020952e-02 1.07337490e-01 -3.73839672e-02
-1.85424702e-02 -9.77438210e-02 1.35022125e-01 5.90582636e-02
1.63908052e-03 -7.43389943e-02]
[-8.23076663e-02 -3.82813864e-02 1.37970988e-01 -2.99934549e-02
-2.68845690e-02 -1.02355369e-01 9.12888953e-02 -3.53093145e-04
-3.53520397e-02 -5.78286482e-02]
[-9.50498624e-02 -4.60312526e-02 1.67575362e-01 -2.74771200e-02
-8.43013199e-03 -8.06010196e-02 5.54779549e-02 -4.75758513e-02
-6.47748359e-02 -2.75317154e-02]
[-7.16351738e-02 -4.41277788e-02 1.45701763e-01 1.24855432e-03
1.46169942e-02 -7.60465568e-02 5.93749022e-02 -6.17292026e-02
-5.86224445e-02 -2.66687336e-02]
[-6.16778231e-02 -4.72333447e-02 1.11211251e-01 3.88144082e-02
4.30879841e-02 -3.65799040e-02 5.09810184e-02 -3.84721677e-02
-2.77666367e-02 -1.34022960e-02]
[-4.86744743e-02 -2.13668975e-02 6.73850765e-02 2.83607342e-02
3.75809411e-02 -2.51348881e-02 2.83121865e-02 -2.06509712e-02
-2.13905453e-02 -4.74981693e-03]
[-3.49021170e-02 -1.49223217e-02 5.07296592e-02 3.44248124e-02
1.76991886e-02 -8.80179158e-03 -8.46805269e-03 -2.06985742e-02
-3.19633058e-02 9.95127577e-03]
[-1.98075939e-02 -5.06471247e-03 2.12698655e-02 3.85639213e-02
2.10663157e-02 -1.11487135e-02 -7.10696841e-03 -6.24988002e-03
-1.48659066e-02 6.16684391e-03]]
[[ 3.62599737e-03 2.94539921e-03 -2.58042726e-02 6.62572147e-03
8.17662302e-03 3.21574858e-02 8.00241626e-02 2.88561398e-03
2.25625696e-02 -2.97773596e-02]
[ 1.24748018e-02 3.25943748e-02 -6.98699079e-02 -1.99415204e-02
8.98645467e-03 2.75950596e-02 6.77296140e-02 -1.72919213e-03
2.37468487e-02 -3.50077631e-02]
[ 2.12325089e-02 3.92948087e-02 -9.18501935e-02 -2.87299468e-02
-3.20147555e-04 1.46104383e-02 7.63723852e-02 1.26062855e-02
2.64755854e-02 -2.93706603e-02]
[ 2.27776689e-02 4.12789392e-02 -8.49281608e-02 -3.98009553e-02
-2.13557998e-02 -1.36286022e-03 4.99668684e-02 2.34237759e-02
2.02189041e-02 -2.05926547e-02]
[ 1.99018813e-02 3.41225257e-02 -8.66679609e-02 -3.45312179e-02
-1.78512544e-02 -4.82369387e-03 2.14085347e-02 3.65739249e-02
9.21655481e-03 -1.30496637e-03]
[ 2.62904309e-02 6.18586705e-02 -9.56436806e-02 -6.58077095e-02
-2.72908745e-02 4.50642484e-03 1.30500698e-02 1.51553697e-02
1.34644969e-02 -5.49530647e-03]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]]]
[[[-1.28601883e-02 -1.22699760e-04 2.87583047e-02 -1.12176756e-02
-8.02404394e-03 -5.23670159e-02 1.56707537e-01 1.06681455e-01
2.62892937e-02 -1.14040429e-01]
[-3.43349540e-02 -2.66737734e-03 6.48235389e-02 -1.54243659e-02
-7.99647407e-03 -7.89364081e-02 1.52510327e-01 1.01263909e-01
2.12051295e-02 -9.87228364e-02]
[-7.06332778e-02 -1.31020952e-02 1.07337490e-01 -3.73839672e-02
-1.85424702e-02 -9.77438210e-02 1.35022125e-01 5.90582636e-02
1.63908052e-03 -7.43389943e-02]
[-8.23076663e-02 -3.82813864e-02 1.37970988e-01 -2.99934549e-02
-2.68845690e-02 -1.02355369e-01 9.12888953e-02 -3.53093145e-04
-3.53520397e-02 -5.78286482e-02]
[-9.50498624e-02 -4.60312526e-02 1.67575362e-01 -2.74771200e-02
-8.43013199e-03 -8.06010196e-02 5.54779549e-02 -4.75758513e-02
-6.47748359e-02 -2.75317154e-02]
[-7.16351738e-02 -4.41277788e-02 1.45701763e-01 1.24855432e-03
1.46169942e-02 -7.60465568e-02 5.93749022e-02 -6.17292026e-02
-5.86224445e-02 -2.66687336e-02]
[-6.16778231e-02 -4.72333447e-02 1.11211251e-01 3.88144082e-02
4.30879841e-02 -3.65799040e-02 5.09810184e-02 -3.84721677e-02
-2.77666367e-02 -1.34022960e-02]
[-4.86744743e-02 -2.13668975e-02 6.73850765e-02 2.83607342e-02
3.75809411e-02 -2.51348881e-02 2.83121865e-02 -2.06509712e-02
-2.13905453e-02 -4.74981693e-03]
[-3.49021170e-02 -1.49223217e-02 5.07296592e-02 3.44248124e-02
1.76991886e-02 -8.80179158e-03 -8.46805269e-03 -2.06985742e-02
-3.19633058e-02 9.95127577e-03]
[-1.98075939e-02 -5.06471247e-03 2.12698655e-02 3.85639213e-02
2.10663157e-02 -1.11487135e-02 -7.10696841e-03 -6.24988002e-03
-1.48659066e-02 6.16684391e-03]]
[[ 3.62599737e-03 2.94539921e-03 -2.58042726e-02 6.62572147e-03
8.17662302e-03 3.21574858e-02 8.00241626e-02 2.88561398e-03
2.25625696e-02 -2.97773596e-02]
[ 1.24748018e-02 3.25943748e-02 -6.98699079e-02 -1.99415204e-02
8.98645467e-03 2.75950596e-02 6.77296140e-02 -1.72919213e-03
2.37468487e-02 -3.50077631e-02]
[ 2.12325089e-02 3.92948087e-02 -9.18501935e-02 -2.87299468e-02
-3.20147555e-04 1.46104383e-02 7.63723852e-02 1.26062855e-02
2.64755854e-02 -2.93706603e-02]
[ 2.27776689e-02 4.12789392e-02 -8.49281608e-02 -3.98009553e-02
-2.13557998e-02 -1.36286022e-03 4.99668684e-02 2.34237759e-02
2.02189041e-02 -2.05926547e-02]
[ 1.99018813e-02 3.41225257e-02 -8.66679609e-02 -3.45312179e-02
-1.78512544e-02 -4.82369387e-03 2.14085347e-02 3.65739249e-02
9.21655481e-03 -1.30496637e-03]
[ 2.62904309e-02 6.18586705e-02 -9.56436806e-02 -6.58077095e-02
-2.72908745e-02 4.50642484e-03 1.30500698e-02 1.51553697e-02
1.34644969e-02 -5.49530647e-03]
[ 3.05670502e-02 6.18263175e-02 -7.87989639e-02 -7.46175202e-02
-3.49183715e-02 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 3.33352693e-02 5.59230982e-02 -6.25883387e-02 -7.42776119e-02
-3.54990111e-02 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 3.38701860e-02 4.70675661e-02 -4.76290537e-02 -6.78331036e-02
-3.13554307e-02 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 3.23695247e-02 3.73589351e-02 -3.50100669e-02 -5.80015682e-02
-2.49360807e-02 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]]]
As the output shows, with sequence_length passed in, the padded portion of stack_bidirectional_dynamic_rnn's output is all zeros, whereas without the actual-length parameter the padded portion still carries values.
These experiments finally clarify how sequence_length affects the results of stack_bidirectional_dynamic_rnn, bidirectional_dynamic_rnn, and dynamic_rnn. Although the docstrings all mark it as optional, in my opinion it is best to pass in the actual lengths, avoiding both potential errors and unnecessary intermediate computation. Whether it changes the final result depends on how the outputs are processed downstream.
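If passing sequence_length is the recommendation, the lengths themselves can be recovered from a zero-padded batch. Here is a numpy sketch, assuming no real timestep is entirely equal to the padding value; `sequence_lengths_from_padding` is a hypothetical helper name:

```python
import numpy as np

def sequence_lengths_from_padding(x, pad_value=0.0):
    # x: [batch, max_time, feat]; a timestep counts as real if any
    # of its features differs from pad_value
    not_pad = np.any(x != pad_value, axis=2)  # [batch, max_time] bool mask
    return not_pad.sum(axis=1)                # real timesteps per example

X = np.random.randn(2, 10, 8)
X[1, 6:] = 0.0
print(sequence_lengths_from_padding(X))  # [10  6]
```

In graph-mode TensorFlow, a commonly seen variant of the same idea combines tf.sign, tf.reduce_max, and tf.reduce_sum over the input tensor.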
When I have time, I will also look at how sequence_length affects the prediction and the loss.
As a still-very-green programmer, I have learned a lot of practical knowledge on CSDN. This is my first blog post; going forward I will keep recording the problems I run into while coding, to improve my coding skills.