最近在看tensorflow中seq2seq的实现,对 rnn 返回的outputs和final_states 有点疑惑。按照我的理解,对于RNN 返回的outputs 是整个序列在time 上的展开,即 outputs = [h1,⋯,hT] [ h 1 , ⋯ , h T ] ,而 final_states 应该为 hT h T , 但是如果这样的话显然不合适,因为返回一个outputs就可以了,那么 final_states 会不会是最后那个cell单元。看代码BasicLSTMCell中有个参数state_is_tuple,按照官网解释,state_is_tuple=True 表示返回的是 (cT,hT) ( c T , h T ) , 如果state_is_tuple=False表示返回的是 [cT;hT] [ c T ; h T ] ,表示 cT c T 和 hT h T 的拼接。
import tensorflow as tf
import numpy as np
# 创建输入数据
X = np.random.randn(2, 10, 8)
# 第二个example长度为6
X[1, 6:] = 0
X_lengths = [10, 6]
cell = tf.contrib.rnn.BasicLSTMCell(num_units=16, state_is_tuple=True)
outputs, last_states = tf.nn.dynamic_rnn(
cell=cell,
dtype=tf.float64,
sequence_length=X_lengths,
inputs=X)
result = tf.contrib.learn.run_n(
{"outputs": outputs, "final_states": last_states},
n=1,
feed_dict=None)
print(result[0])
输出结果为
{'outputs': array([[[ 8.51600608e-02, 1.20660209e-01, -8.30597066e-02,
8.94003039e-02, -4.79661940e-02, 6.33470364e-02,
-1.56834621e-01, 4.85511425e-02, 3.02277184e-02,
-8.24393029e-02, -1.08487943e-01, -5.45718743e-02,
1.06717348e-01, -3.63896366e-02, 8.62340956e-02,
1.00542264e-01],
[ 5.26434674e-02, 8.82528803e-02, 8.18372465e-02,
-4.70024490e-04, -1.12268602e-01, 7.81389471e-02,
1.56262981e-01, 1.37541835e-01, 4.07896288e-03,
3.40363337e-02, -1.06232519e-01, -3.45220435e-03,
4.45592240e-02, -2.14986911e-01, -5.29998580e-02,
-4.47588031e-02],
[ -1.15412606e-01, -3.15060562e-02, -3.53337836e-02,
-1.41307868e-01, -2.03983857e-01, 1.52195883e-01,
5.93960067e-02, 5.22762772e-02, 6.51036954e-02,
1.62629488e-02, -6.71094027e-02, 1.78651720e-01,
1.43935632e-02, -8.98583193e-02, 8.51871234e-02,
1.12750453e-01],
[ -1.38490492e-01, 5.49294021e-02, -5.79125883e-02,
-1.34913379e-02, -3.15632788e-02, 1.95746132e-01,
4.04904864e-02, 1.36532347e-01, 8.69326328e-02,
1.48219807e-01, -6.75214799e-02, 6.74295126e-02,
-7.10565155e-02, -1.45115152e-01, 5.61782256e-02,
-2.67229487e-02],
[ -1.33728334e-01, 6.50316413e-03, -8.28467300e-02,
-2.13676304e-02, 1.99136425e-01, 2.07323407e-01,
-7.04399740e-02, 6.64707052e-02, 1.28669041e-01,
-3.99008089e-03, -7.76149784e-02, 4.61886935e-02,
-6.50376977e-02, -1.18495263e-01, 1.17621248e-01,
3.39481317e-02],
[ -1.17954995e-01, 7.39067638e-02, -4.73444202e-02,
7.34059918e-02, 1.13222390e-01, 2.53695281e-01,
-1.24950329e-01, 2.24618668e-02, 6.06518076e-02,
5.96379696e-02, -1.24584380e-01, 3.29196227e-02,
-1.01975574e-01, 7.47888395e-02, 3.75111982e-02,
-2.45382948e-03],
[ -1.99299948e-01, -7.39883224e-02, 9.04839265e-02,
2.56283750e-02, 2.80005345e-01, 3.34384495e-01,
-4.94238377e-03, 1.31329177e-01, 5.09313718e-02,
2.07172398e-01, -1.27001974e-01, 2.07658264e-01,
-2.35004803e-01, 9.10458478e-02, -3.33668414e-02,
-4.22343926e-02],
[ -6.93773303e-02, -8.29434270e-02, 1.74555290e-01,
4.32463735e-02, 1.92333551e-01, 8.05206310e-02,
-4.67456180e-02, 2.04428783e-01, 5.42508192e-02,
-1.78346205e-02, -8.51749327e-02, 1.76002661e-01,
-1.78980937e-01, 9.33172891e-02, -7.65833685e-02,
5.22866062e-02],
[ -1.28708341e-01, -1.07299070e-01, 1.74460808e-02,
-1.41812305e-01, -1.75518783e-01, 8.41806430e-02,
8.74052786e-02, 8.54392758e-02, 8.94714372e-02,
-7.60687913e-02, 5.32559011e-04, 2.12149332e-01,
-1.76561832e-01, 1.10551504e-01, 5.45956606e-02,
2.27662777e-01],
[ -1.38716746e-01, 3.50782336e-02, 2.44049428e-02,
-6.19580995e-02, -1.22048162e-01, -6.71233156e-03,
1.49210897e-01, 2.68549670e-02, 4.08326486e-02,
-3.16653869e-02, 1.05464326e-01, 1.64826060e-01,
-1.17569311e-01, 1.11453975e-01, -8.88766093e-02,
7.84726110e-03]],
[[ -5.85491016e-02, 1.69248395e-01, -1.67613134e-01,
2.40518851e-01, 5.23920082e-02, 9.71959834e-02,
-2.93255173e-02, -8.00785565e-02, 5.00090944e-02,
-6.95419630e-02, 2.68765224e-02, -1.63707918e-01,
6.18403061e-02, -8.36852013e-02, 2.69253483e-01,
-7.63595695e-02],
[ -3.14072350e-02, 2.39703454e-01, -2.38231502e-01,
9.26873152e-02, 8.10852132e-02, 1.27162577e-01,
-5.73377612e-02, -1.39968334e-01, 3.99482735e-02,
-8.47072368e-02, 1.83764730e-02, -2.14926501e-01,
1.06305866e-04, 8.34569414e-03, 2.06806184e-01,
-7.32392374e-02],
[ -1.93460859e-02, 6.24097978e-02, -1.03516662e-01,
-1.26342672e-02, -2.08758636e-02, 1.66454528e-01,
-5.81779960e-02, -9.20345358e-02, -8.29544111e-03,
-1.91938126e-03, -8.22920924e-02, -5.05529121e-02,
-1.58349559e-02, 3.58339902e-02, 8.59197276e-02,
2.68994241e-02],
[ 1.21182823e-01, -1.60383966e-01, 5.01372858e-02,
-1.05416048e-01, -1.13352303e-01, 7.64203395e-03,
-8.72742212e-02, -2.58785716e-03, -6.08090481e-02,
-1.33383746e-02, -1.75964119e-01, 4.07472718e-02,
-1.35161904e-03, 2.60193309e-02, -2.81093680e-02,
2.01884647e-01],
[ 1.40900464e-01, -9.81550810e-02, 3.84887706e-02,
-1.17454971e-01, -2.05570483e-01, -1.04132781e-01,
-3.25062432e-02, -4.22595784e-02, -7.68583127e-02,
-1.44954773e-01, -4.91470908e-02, -2.19290014e-03,
5.30027149e-02, 4.02266622e-02, -2.62458365e-02,
1.97159511e-01],
[ 6.42759069e-03, -1.79224216e-01, 5.47793212e-02,
-1.82429195e-01, -2.33649002e-01, -2.23675583e-02,
6.23475515e-02, -7.94249204e-02, -4.07472203e-02,
-1.65480823e-01, 4.92744710e-02, 1.17531775e-01,
-6.34366232e-02, 1.08770149e-01, 1.52727682e-02,
1.51494894e-01],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00]]]), 'final_states': LSTMStateTuple(c=array([[-0.28682583, 0.07014927, 0.04587004, -0.137129 , -0.32802295,
-0.01273347, 0.34420444, 0.05639795, 0.09310471, -0.05722055,
0.186892 , 0.36725005, -0.28890884, 0.23256402, -0.13752357,
0.01838055],
[ 0.01344121, -0.35903624, 0.1086343 , -0.38516179, -0.43465922,
-0.04556597, 0.14225783, -0.18317949, -0.08140523, -0.29619229,
0.07555464, 0.22769443, -0.10944077, 0.23001641, 0.04418935,
0.28007862]]), h=array([[-0.13871675, 0.03507823, 0.02440494, -0.0619581 , -0.12204816,
-0.00671233, 0.1492109 , 0.02685497, 0.04083265, -0.03166539,
0.10546433, 0.16482606, -0.11756931, 0.11145397, -0.08887661,
0.00784726],
[ 0.00642759, -0.17922422, 0.05477932, -0.18242919, -0.233649 ,
-0.02236756, 0.06234755, -0.07942492, -0.04074722, -0.16548082,
0.04927447, 0.11753178, -0.06343662, 0.10877015, 0.01527277,
0.15149489]]))}
从上面的结果中可以看到,final_states 分别为
(c9,h9),(c5,h5)
(
c
9
,
h
9
)
,
(
c
5
,
h
5
)
,因为第二个样本的time-step=6,所以才有
(c5,h5)
(
c
5
,
h
5
)
.
将 state_is_tuple=False 的运行结果如下
{'outputs': array([[[-0.0520232 , -0.00596184, -0.15686908, -0.12699241, -0.0998554 ,
-0.13780395, 0.12679938, 0.13057581, 0.01394446, -0.03188085,
0.13801467, -0.10640567, 0.00101057, 0.11505147, 0.14934653,
0.00051895],
[ 0.0188593 , -0.09666485, -0.22422633, -0.19876283, -0.14657046,
-0.03839021, 0.23788255, -0.05810481, 0.17466428, 0.05607761,
0.01819319, -0.11747983, -0.00355588, 0.05552798, 0.21469954,
0.02979367],
[-0.07972986, -0.09420505, -0.24408267, -0.18565055, -0.09185143,
-0.0095228 , 0.23303042, -0.00075869, 0.12077442, 0.04463999,
-0.04770715, -0.06327784, 0.045044 , 0.07344893, 0.24777579,
-0.00886227],
[ 0.00519919, -0.02781198, 0.05779533, -0.00535419, 0.09587848,
0.11860488, 0.06507637, -0.06036546, 0.1939022 , 0.15366904,
-0.05568861, 0.06830101, 0.1332387 , -0.02283672, 0.0546639 ,
0.12544396],
[ 0.02893471, -0.08692809, 0.03823236, 0.07645208, 0.07632264,
0.19249265, 0.04535065, -0.12270601, 0.16775191, 0.1155008 ,
-0.08539387, 0.00234452, -0.04358566, -0.05436525, -0.02478817,
0.2553227 ],
[ 0.09466621, -0.03739176, 0.01791674, 0.12777455, 0.15859402,
0.0646465 , -0.0835708 , -0.17655872, 0.10373822, 0.14420393,
0.03712749, -0.00675274, -0.18880326, -0.17735889, 0.0100914 ,
0.19743785],
[ 0.15431932, -0.05348529, 0.05334238, 0.16724173, 0.18007884,
0.03453815, -0.1851201 , -0.14165195, 0.00732323, 0.09458506,
0.10690102, 0.02041647, -0.241884 , -0.2522594 , -0.00689178,
0.1445705 ],
[ 0.24555932, 0.09540935, 0.14411261, 0.34043302, 0.12379964,
-0.1129966 , -0.19834936, -0.06646834, 0.12707064, 0.16486904,
0.040797 , 0.00081681, -0.19858868, -0.2392037 , 0.05258687,
-0.18024845],
[ 0.08527663, -0.01553275, -0.04132372, 0.04184508, -0.04872827,
-0.18967336, -0.07614696, 0.19419587, -0.04699756, -0.04536202,
0.18612597, -0.03260128, -0.10163605, -0.15020102, 0.13656208,
-0.07045483],
[ 0.00902652, 0.00802411, 0.00349441, 0.0066363 , 0.02812586,
-0.2299042 , 0.00387813, 0.24113983, -0.09704951, -0.10187069,
0.13439021, -0.01243524, -0.08960348, -0.10325231, 0.15756624,
-0.29002312]],
[[-0.05148575, 0.06720759, -0.07206267, -0.06812521, 0.03614493,
0.04767764, -0.09072717, 0.13137064, -0.16131891, -0.23117387,
0.0621554 , -0.09076673, 0.1179775 , 0.32040227, -0.05554256,
0.09928288],
[-0.09680876, 0.16121176, -0.13065313, -0.15421254, 0.07431257,
-0.04082315, -0.02974331, 0.29206754, -0.15111419, -0.29595181,
0.13282585, -0.06019135, 0.10809174, 0.37644085, -0.0440165 ,
0.00685254],
[-0.03421361, 0.13748931, -0.09908259, -0.22864179, -0.00654961,
-0.0785063 , -0.02805491, 0.10213422, -0.13717202, -0.06699498,
0.17477517, 0.0122198 , 0.06943023, 0.44624913, -0.05628462,
0.08216506],
[ 0.01781805, 0.28241267, -0.05246943, -0.08493888, -0.02540539,
-0.13000158, -0.03256825, -0.0878165 , -0.01878073, 0.26896141,
0.13264408, 0.12267464, 0.0242984 , 0.40396969, -0.12822519,
0.15831019],
[ 0.08107763, 0.12859033, -0.05616897, -0.03728299, -0.0312746 ,
-0.08662511, -0.02553875, -0.0985618 , 0.0650238 , 0.03627828,
0.12857739, 0.02939889, -0.03254754, 0.18302167, -0.0657735 ,
0.12831271],
[ 0.06255597, -0.00889141, -0.10634166, -0.09577743, -0.08439649,
-0.12606134, 0.0083995 , -0.12306813, -0.03984412, -0.09111913,
0.10249516, -0.05978723, -0.16387474, 0.17921828, 0.00099589,
0.13509734],
[ 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[ 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[ 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[ 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ]]]), 'final_states': array([[ 0.01704464, 0.01375266, 0.00569932, 0.01359386, 0.0599031 ,
-0.5310615 , 0.01106043, 0.63132237, -0.19392737, -0.22410165,
0.24611573, -0.04944791, -0.21992734, -0.26848952, 0.28565425,
-0.59903395, 0.00902652, 0.00802411, 0.00349441, 0.0066363 ,
0.02812586, -0.2299042 , 0.00387813, 0.24113983, -0.09704951,
-0.10187069, 0.13439021, -0.01243524, -0.08960348, -0.10325231,
0.15756624, -0.29002312],
[ 0.12148278, -0.01844205, -0.3399909 , -0.23447282, -0.18236695,
-0.18892043, 0.01784129, -0.27405052, -0.06386667, -0.22924761,
0.16695596, -0.11379451, -0.25193653, 0.46508263, 0.00269928,
0.34539569, 0.06255597, -0.00889141, -0.10634166, -0.09577743,
-0.08439649, -0.12606134, 0.0083995 , -0.12306813, -0.03984412,
-0.09111913, 0.10249516, -0.05978723, -0.16387474, 0.17921828,
0.00099589, 0.13509734]])}
采用双向BiLSTM
outputs, last_states = tf.nn.bidirectional_dynamic_rnn(
cell_fw=cell,
cell_bw=cell,
inputs=X,
sequence_length=X_lengths,
dtype=tf.float64,
time_major=False
)
运行结果如下
{'outputs': (array([[[ -8.41238987e-02, 2.02342451e-01, 2.60368436e-02,
-6.91997149e-02, -1.00815477e-02, -1.84710375e-02,
4.73017362e-02, -1.46565564e-01, -3.46681283e-02,
1.13812954e-01, 1.12076022e-01, 9.64718047e-02,
5.89444571e-02, 1.13485622e-02, -5.93493561e-02,
7.73265818e-02],
[ 1.90049974e-03, 3.13072057e-01, -7.36657193e-02,
5.85418739e-02, 1.73268038e-01, 6.05968861e-02,
7.95394467e-02, -4.45080027e-02, 5.93563046e-02,
5.73828147e-02, 1.05410051e-01, 8.10913136e-02,
7.32087617e-03, -1.65634074e-02, -3.33619727e-03,
-3.27739570e-02],
[ 1.17158311e-01, 7.44036804e-02, -3.29324582e-03,
9.05722883e-02, 8.89695327e-02, 1.12105162e-01,
1.00975061e-01, 5.15468347e-02, 5.14357339e-02,
2.06787209e-02, -2.83383422e-02, 2.11286105e-02,
3.43771341e-02, -7.92418082e-02, 8.53362033e-02,
-8.16963366e-02],
[ -4.25550595e-02, -7.65800479e-04, 1.74860287e-01,
9.02672757e-02, 4.04515710e-02, 1.11408791e-01,
1.81697921e-01, 2.39430182e-02, -8.38633094e-03,
1.71085707e-01, 4.48776829e-02, 8.19355270e-02,
1.65336573e-01, -1.86586163e-01, 1.40200445e-01,
-4.97211731e-02],
[ 9.68253056e-02, 6.06495878e-02, 2.08696567e-01,
3.87971635e-02, 5.61433717e-02, 3.35399408e-02,
1.11537949e-01, 4.59420685e-02, 1.41582280e-02,
7.49530534e-03, 1.55239730e-02, -4.05773611e-02,
1.58862006e-01, -2.09700998e-01, -2.13514475e-03,
-3.57567483e-03],
[ 1.32149638e-01, 9.30745932e-02, 2.46366699e-01,
9.27931858e-02, 1.09161202e-01, 8.44090243e-02,
6.41185248e-02, 6.26056890e-02, 7.41848483e-03,
-5.86925305e-02, 3.52253794e-02, -3.67522778e-02,
9.78261676e-02, -1.45995336e-01, -6.36220468e-02,
6.23390181e-02],
[ 2.32111735e-01, -2.37044221e-02, 1.77017646e-01,
-3.97338283e-04, 1.67123262e-03, 7.70140793e-02,
6.53800462e-02, 1.22855548e-02, 2.31961441e-02,
-1.26319736e-01, 6.41205262e-02, -1.53632481e-01,
8.28545616e-02, -1.20705745e-01, -4.10238845e-02,
9.20385553e-02],
[ 2.66012734e-01, 1.05477379e-01, 4.21216468e-01,
1.70095562e-01, -4.77457867e-02, 1.68168287e-01,
8.39350881e-02, -3.55770945e-02, -7.47878387e-02,
-2.45388387e-01, 1.18436625e-01, -7.06543538e-02,
1.85329579e-01, -1.84038439e-01, 7.32075603e-02,
9.42068544e-02],
[ 7.95414388e-02, 1.16495551e-01, 1.97109709e-01,
1.37920460e-02, -2.43955191e-03, 1.06759844e-01,
7.55164241e-02, -1.15271431e-01, -6.19764104e-02,
-1.80050690e-01, 1.04730874e-01, -9.92265522e-03,
1.36772531e-01, -9.17064895e-02, -8.42246795e-02,
1.01567247e-01],
[ 1.68130664e-01, 1.21676118e-01, 3.33891988e-02,
-3.62197881e-02, -2.31209753e-02, -4.31261261e-02,
3.24759979e-03, 2.58159477e-02, 7.60624235e-02,
-1.77180028e-01, -1.77196301e-02, -8.36235778e-02,
3.21666168e-03, -6.89595460e-03, 2.80242017e-02,
1.99551342e-02]],
[[ 1.56593441e-02, -3.41630711e-02, 2.05353309e-02,
-8.23938793e-02, -9.61087545e-02, -1.14604591e-01,
2.42885439e-02, 3.98282281e-02, -2.66879271e-02,
-8.42482121e-02, -2.03660879e-01, -1.11103103e-01,
-1.72252207e-01, 4.00229824e-02, 1.23120260e-01,
-1.10854331e-01],
[ 5.32126508e-02, -1.04440576e-01, -9.68604478e-02,
1.63228590e-02, 1.01150650e-01, -1.02453878e-02,
-1.33852475e-03, 1.64967070e-01, -8.67544927e-03,
-7.64610844e-02, -3.54521460e-01, -3.66237321e-02,
-2.45614683e-01, 4.51928160e-02, 6.81918975e-02,
-1.47578191e-01],
[ -8.66380079e-02, -1.01205470e-01, 1.25281612e-01,
8.15776305e-02, 1.32941869e-01, 7.84578282e-02,
-2.97582357e-02, 1.60467739e-01, -1.51084771e-01,
-6.00740157e-02, -2.50723724e-01, 1.69384964e-01,
2.09864500e-02, -1.09103700e-01, 1.35223647e-01,
-5.26320391e-02],
[ -3.79775701e-02, -9.84794905e-03, 1.04281425e-01,
2.18962102e-01, 9.49885755e-02, 9.30334434e-02,
1.20880917e-02, 4.92450967e-02, -1.39313910e-01,
-1.67242457e-01, -8.92664684e-02, 5.76337605e-02,
1.06973235e-01, -1.88442473e-01, 1.16279010e-01,
-6.63430587e-02],
[ 1.13913179e-01, -7.01235476e-02, 8.72963709e-02,
1.04267413e-01, -5.76935594e-02, 3.81271792e-02,
4.09102123e-03, 7.76470917e-02, -6.38018038e-03,
-2.10545446e-01, -3.69164679e-02, -1.20748943e-01,
1.32309592e-01, -1.45925061e-01, 1.47558076e-01,
-7.00836697e-02],
[ 4.32247065e-02, -6.83892073e-02, 7.56469023e-03,
-6.24281552e-02, 2.80944091e-02, 1.22108990e-02,
-4.95889882e-04, 1.40971813e-01, 1.07006212e-01,
-1.11817313e-01, -2.12919101e-01, -1.15743608e-01,
-6.11031704e-02, -2.48300200e-02, 7.24766670e-03,
-7.83430451e-02],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00]]]), array([[[ 0.02356511, 0.24838725, 0.17932081, 0.05942439, 0.15138192,
0.12581299, 0.07806339, -0.04999007, 0.02852398, -0.01691062,
0.12604886, 0.03830202, 0.23902503, -0.10546712, 0.06509197,
0.08460939],
[ 0.14112674, 0.04603104, 0.16150333, 0.12562632, 0.16140194,
0.13926239, 0.06831767, 0.13488513, 0.06390905, -0.09083964,
0.03295962, -0.01229307, 0.21374753, -0.10903434, 0.11961082,
-0.03469615],
[ 0.25343172, -0.14971647, 0.2810062 , 0.03923167, -0.01272575,
0.07319106, 0.06496102, 0.04369333, -0.00583105, -0.14490683,
0.04070688, -0.11250917, 0.24891372, -0.18076371, 0.10464212,
0.05199773],
[ 0.17880187, -0.04316056, 0.30915784, -0.03128962, 0.0136552 ,
0.00114268, 0.11112965, -0.02416813, -0.01660347, -0.18926507,
0.12892229, -0.09453105, 0.18057936, -0.17712924, 0.04842806,
0.13493346],
[ 0.2715969 , 0.02844215, 0.17219262, -0.04691747, 0.07650773,
-0.01564628, 0.05903233, 0.00140291, 0.03676231, -0.29671783,
0.05205306, -0.10022505, -0.00608728, -0.11042412, -0.0303516 ,
0.05871938],
[ 0.17246075, -0.0050864 , 0.13435615, 0.044755 , 0.05607737,
0.07413086, 0.09217622, -0.03587575, 0.02821449, -0.23358441,
0.08262469, -0.05006317, 0.02842524, -0.0650939 , 0.0275634 ,
0.07439376],
[ 0.18414017, -0.01139999, 0.07430255, -0.01925044, -0.02509538,
0.04357774, 0.1246182 , -0.0357719 , 0.02225083, -0.28293001,
0.08652195, -0.08315333, 0.03829619, -0.06317638, 0.06140839,
0.00854522],
[ 0.11787692, 0.16528217, 0.0742104 , 0.09235664, 0.10792024,
0.05196128, 0.09740701, -0.02470125, -0.03088181, -0.220567 ,
0.03927874, 0.06127843, 0.04881179, -0.07778106, 0.068567 ,
-0.02297281],
[-0.01108608, 0.07571883, -0.16907375, -0.08356024, 0.12267888,
-0.11135617, 0.01642111, 0.03403353, 0.08506832, -0.06965933,
-0.1217141 , -0.00500899, -0.24758602, 0.11082664, -0.07299259,
-0.0581366 ],
[ 0.08893831, 0.04202115, -0.06253906, -0.04473577, 0.0229234 ,
-0.12935649, -0.0163277 , 0.13139815, 0.13483017, -0.08191296,
-0.08125861, -0.07044632, -0.10374899, 0.03506476, 0.06903886,
-0.07480185]],
[[ 0.01747829, -0.08909235, 0.02053894, 0.01685996, 0.0389529 ,
-0.00709683, -0.00462855, 0.08696204, -0.03342217, -0.21788043,
-0.30605847, -0.1349446 , -0.11533038, -0.09036049, 0.15674459,
-0.126863 ],
[ 0.00525433, -0.10835235, 0.0027264 , 0.08466345, 0.14112028,
0.09043956, -0.04722811, 0.15436025, -0.00785834, -0.08981446,
-0.24305637, -0.00319932, 0.05808224, -0.12561799, 0.04564026,
-0.02585145],
[-0.06045172, -0.09233718, 0.22095997, 0.04789267, 0.03415064,
0.0262382 , 0.01256547, 0.10090689, -0.05170264, -0.13021805,
-0.0060994 , 0.07757477, 0.06438366, -0.10231362, 0.26221725,
0.02315334],
[ 0.15577114, -0.07492413, -0.01727874, -0.07142947, -0.06811142,
-0.1202229 , 0.05028568, 0.0590934 , 0.16453847, -0.23882856,
0.01029733, -0.15489522, 0.003675 , 0.02758079, 0.11096516,
-0.10430096],
[ 0.10145966, -0.10858013, -0.0650702 , -0.18066065, -0.00131774,
-0.08608264, 0.01451162, 0.09330788, 0.1964142 , -0.11445147,
-0.08806287, -0.15021895, -0.07837124, 0.07806647, 0.01695358,
-0.09840479],
[-0.01788662, -0.03722207, -0.0847649 , -0.10571478, 0.07626006,
-0.01792801, 0.01216043, 0.07324065, 0.10951483, 0.0420184 ,
-0.19723146, -0.02916167, -0.19582316, 0.09968804, -0.16278298,
-0.05184987],
[ 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[ 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[ 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[ 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ]]])), 'final_states': (LSTMStateTuple(c=array([[ 0.35089277, 0.28169591, 0.10655786, -0.06038076, -0.04643542,
-0.08835599, 0.01115685, 0.05617012, 0.13590812, -0.45686304,
-0.03412385, -0.25160794, 0.0052527 , -0.01301877, 0.07284551,
0.0505886 ],
[ 0.09893459, -0.17312469, 0.01775045, -0.16009508, 0.06735036,
0.02437652, -0.00088507, 0.33421125, 0.29787557, -0.2440617 ,
-0.33060605, -0.29383554, -0.0913775 , -0.04287903, 0.01267763,
-0.18018947]]), h=array([[ 0.16813066, 0.12167612, 0.0333892 , -0.03621979, -0.02312098,
-0.04312613, 0.0032476 , 0.02581595, 0.07606242, -0.17718003,
-0.01771963, -0.08362358, 0.00321666, -0.00689595, 0.0280242 ,
0.01995513],
[ 0.04322471, -0.06838921, 0.00756469, -0.06242816, 0.02809441,
0.0122109 , -0.00049589, 0.14097181, 0.10700621, -0.11181731,
-0.2129191 , -0.11574361, -0.06110317, -0.02483002, 0.00724767,
-0.07834305]])), LSTMStateTuple(c=array([[ 0.05235715, 0.4449884 , 0.29395379, 0.10136493, 0.23366848,
0.23763545, 0.11706335, -0.09597744, 0.04718805, -0.0265409 ,
0.41051355, 0.0609024 , 0.50751285, -0.24001079, 0.1205789 ,
0.16522967],
[ 0.02442768, -0.26505515, 0.06781086, 0.03441641, 0.08381474,
-0.01540371, -0.01586349, 0.36816326, -0.06578872, -0.48389251,
-0.49614392, -0.30551384, -0.17350058, -0.11728567, 0.37299283,
-0.30650555]]), h=array([[ 0.02356511, 0.24838725, 0.17932081, 0.05942439, 0.15138192,
0.12581299, 0.07806339, -0.04999007, 0.02852398, -0.01691062,
0.12604886, 0.03830202, 0.23902503, -0.10546712, 0.06509197,
0.08460939],
[ 0.01747829, -0.08909235, 0.02053894, 0.01685996, 0.0389529 ,
-0.00709683, -0.00462855, 0.08696204, -0.03342217, -0.21788043,
-0.30605847, -0.1349446 , -0.11533038, -0.09036049, 0.15674459,
-0.126863 ]])))}
可以看到,对于forward来说,final_states 分别为 (c9,h9)和(c5,h5) ( c 9 , h 9 ) 和 ( c 5 , h 5 ) , 但是对于backward来说, finalstates=(c0,h0)和(c0,h0) f i n a l s t a t e s = ( c 0 , h 0 ) 和 ( c 0 , h 0 ) .