[tensorflow 中rnn学习笔记]RNN 中outputs 和 final_states

最近在看tensorflow中seq2seq的实现,对 rnn 返回的outputs和final_states 有点疑惑。按照我的理解,对于RNN 返回的outputs 是整个序列在time 上的展开,即 outputs = [h1,,hT] [ h 1 , ⋯ , h T ] ,而 final_states 应该为 hT h T , 但是如果这样的话显然不合适,因为返回一个outputs就可以了,那么 final_states 会不会是最后那个cell单元。看代码BasicLSTMCell中有个参数state_is_tuple,按照官网解释,state_is_tuple=True 表示返回的是 (cT,hT) ( c T , h T ) , 如果state_is_tuple=False表示返回的是 [cT;hT] [ c T ; h T ] ,表示 cT c T hT h T 的拼接。

import tensorflow as tf
import numpy as np
# 创建输入数据
X = np.random.randn(2, 10, 8)

# 第二个example长度为6
X[1, 6:] = 0
X_lengths = [10, 6]
cell = tf.contrib.rnn.BasicLSTMCell(num_units=16, state_is_tuple=True)
outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)
result = tf.contrib.learn.run_n(
    {"outputs": outputs, "final_states": last_states},
    n=1,
    feed_dict=None)

print(result[0])

输出结果为

{'outputs': array([[[  8.51600608e-02,   1.20660209e-01,  -8.30597066e-02,
           8.94003039e-02,  -4.79661940e-02,   6.33470364e-02,
          -1.56834621e-01,   4.85511425e-02,   3.02277184e-02,
          -8.24393029e-02,  -1.08487943e-01,  -5.45718743e-02,
           1.06717348e-01,  -3.63896366e-02,   8.62340956e-02,
           1.00542264e-01],
        [  5.26434674e-02,   8.82528803e-02,   8.18372465e-02,
          -4.70024490e-04,  -1.12268602e-01,   7.81389471e-02,
           1.56262981e-01,   1.37541835e-01,   4.07896288e-03,
           3.40363337e-02,  -1.06232519e-01,  -3.45220435e-03,
           4.45592240e-02,  -2.14986911e-01,  -5.29998580e-02,
          -4.47588031e-02],
        [ -1.15412606e-01,  -3.15060562e-02,  -3.53337836e-02,
          -1.41307868e-01,  -2.03983857e-01,   1.52195883e-01,
           5.93960067e-02,   5.22762772e-02,   6.51036954e-02,
           1.62629488e-02,  -6.71094027e-02,   1.78651720e-01,
           1.43935632e-02,  -8.98583193e-02,   8.51871234e-02,
           1.12750453e-01],
        [ -1.38490492e-01,   5.49294021e-02,  -5.79125883e-02,
          -1.34913379e-02,  -3.15632788e-02,   1.95746132e-01,
           4.04904864e-02,   1.36532347e-01,   8.69326328e-02,
           1.48219807e-01,  -6.75214799e-02,   6.74295126e-02,
          -7.10565155e-02,  -1.45115152e-01,   5.61782256e-02,
          -2.67229487e-02],
        [ -1.33728334e-01,   6.50316413e-03,  -8.28467300e-02,
          -2.13676304e-02,   1.99136425e-01,   2.07323407e-01,
          -7.04399740e-02,   6.64707052e-02,   1.28669041e-01,
          -3.99008089e-03,  -7.76149784e-02,   4.61886935e-02,
          -6.50376977e-02,  -1.18495263e-01,   1.17621248e-01,
           3.39481317e-02],
        [ -1.17954995e-01,   7.39067638e-02,  -4.73444202e-02,
           7.34059918e-02,   1.13222390e-01,   2.53695281e-01,
          -1.24950329e-01,   2.24618668e-02,   6.06518076e-02,
           5.96379696e-02,  -1.24584380e-01,   3.29196227e-02,
          -1.01975574e-01,   7.47888395e-02,   3.75111982e-02,
          -2.45382948e-03],
        [ -1.99299948e-01,  -7.39883224e-02,   9.04839265e-02,
           2.56283750e-02,   2.80005345e-01,   3.34384495e-01,
          -4.94238377e-03,   1.31329177e-01,   5.09313718e-02,
           2.07172398e-01,  -1.27001974e-01,   2.07658264e-01,
          -2.35004803e-01,   9.10458478e-02,  -3.33668414e-02,
          -4.22343926e-02],
        [ -6.93773303e-02,  -8.29434270e-02,   1.74555290e-01,
           4.32463735e-02,   1.92333551e-01,   8.05206310e-02,
          -4.67456180e-02,   2.04428783e-01,   5.42508192e-02,
          -1.78346205e-02,  -8.51749327e-02,   1.76002661e-01,
          -1.78980937e-01,   9.33172891e-02,  -7.65833685e-02,
           5.22866062e-02],
        [ -1.28708341e-01,  -1.07299070e-01,   1.74460808e-02,
          -1.41812305e-01,  -1.75518783e-01,   8.41806430e-02,
           8.74052786e-02,   8.54392758e-02,   8.94714372e-02,
          -7.60687913e-02,   5.32559011e-04,   2.12149332e-01,
          -1.76561832e-01,   1.10551504e-01,   5.45956606e-02,
           2.27662777e-01],
        [ -1.38716746e-01,   3.50782336e-02,   2.44049428e-02,
          -6.19580995e-02,  -1.22048162e-01,  -6.71233156e-03,
           1.49210897e-01,   2.68549670e-02,   4.08326486e-02,
          -3.16653869e-02,   1.05464326e-01,   1.64826060e-01,
          -1.17569311e-01,   1.11453975e-01,  -8.88766093e-02,
           7.84726110e-03]],

       [[ -5.85491016e-02,   1.69248395e-01,  -1.67613134e-01,
           2.40518851e-01,   5.23920082e-02,   9.71959834e-02,
          -2.93255173e-02,  -8.00785565e-02,   5.00090944e-02,
          -6.95419630e-02,   2.68765224e-02,  -1.63707918e-01,
           6.18403061e-02,  -8.36852013e-02,   2.69253483e-01,
          -7.63595695e-02],
        [ -3.14072350e-02,   2.39703454e-01,  -2.38231502e-01,
           9.26873152e-02,   8.10852132e-02,   1.27162577e-01,
          -5.73377612e-02,  -1.39968334e-01,   3.99482735e-02,
          -8.47072368e-02,   1.83764730e-02,  -2.14926501e-01,
           1.06305866e-04,   8.34569414e-03,   2.06806184e-01,
          -7.32392374e-02],
        [ -1.93460859e-02,   6.24097978e-02,  -1.03516662e-01,
          -1.26342672e-02,  -2.08758636e-02,   1.66454528e-01,
          -5.81779960e-02,  -9.20345358e-02,  -8.29544111e-03,
          -1.91938126e-03,  -8.22920924e-02,  -5.05529121e-02,
          -1.58349559e-02,   3.58339902e-02,   8.59197276e-02,
           2.68994241e-02],
        [  1.21182823e-01,  -1.60383966e-01,   5.01372858e-02,
          -1.05416048e-01,  -1.13352303e-01,   7.64203395e-03,
          -8.72742212e-02,  -2.58785716e-03,  -6.08090481e-02,
          -1.33383746e-02,  -1.75964119e-01,   4.07472718e-02,
          -1.35161904e-03,   2.60193309e-02,  -2.81093680e-02,
           2.01884647e-01],
        [  1.40900464e-01,  -9.81550810e-02,   3.84887706e-02,
          -1.17454971e-01,  -2.05570483e-01,  -1.04132781e-01,
          -3.25062432e-02,  -4.22595784e-02,  -7.68583127e-02,
          -1.44954773e-01,  -4.91470908e-02,  -2.19290014e-03,
           5.30027149e-02,   4.02266622e-02,  -2.62458365e-02,
           1.97159511e-01],
        [  6.42759069e-03,  -1.79224216e-01,   5.47793212e-02,
          -1.82429195e-01,  -2.33649002e-01,  -2.23675583e-02,
           6.23475515e-02,  -7.94249204e-02,  -4.07472203e-02,
          -1.65480823e-01,   4.92744710e-02,   1.17531775e-01,
          -6.34366232e-02,   1.08770149e-01,   1.52727682e-02,
           1.51494894e-01],
        [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00],
        [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00],
        [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00],
        [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00]]]), 'final_states': LSTMStateTuple(c=array([[-0.28682583,  0.07014927,  0.04587004, -0.137129  , -0.32802295,
        -0.01273347,  0.34420444,  0.05639795,  0.09310471, -0.05722055,
         0.186892  ,  0.36725005, -0.28890884,  0.23256402, -0.13752357,
         0.01838055],
       [ 0.01344121, -0.35903624,  0.1086343 , -0.38516179, -0.43465922,
        -0.04556597,  0.14225783, -0.18317949, -0.08140523, -0.29619229,
         0.07555464,  0.22769443, -0.10944077,  0.23001641,  0.04418935,
         0.28007862]]), h=array([[-0.13871675,  0.03507823,  0.02440494, -0.0619581 , -0.12204816,
        -0.00671233,  0.1492109 ,  0.02685497,  0.04083265, -0.03166539,
         0.10546433,  0.16482606, -0.11756931,  0.11145397, -0.08887661,
         0.00784726],
       [ 0.00642759, -0.17922422,  0.05477932, -0.18242919, -0.233649  ,
        -0.02236756,  0.06234755, -0.07942492, -0.04074722, -0.16548082,
         0.04927447,  0.11753178, -0.06343662,  0.10877015,  0.01527277,
         0.15149489]]))}

从上面的结果中可以看到,final_states 分别为 (c9,h9),(c5,h5) ( c 9 , h 9 ) , ( c 5 , h 5 ) ,因为第二个样本的time-step=6,所以才有 (c5,h5) ( c 5 , h 5 ) .
将 state_is_tuple=False 的运行结果如下

{'outputs': array([[[-0.0520232 , -0.00596184, -0.15686908, -0.12699241, -0.0998554 ,
         -0.13780395,  0.12679938,  0.13057581,  0.01394446, -0.03188085,
          0.13801467, -0.10640567,  0.00101057,  0.11505147,  0.14934653,
          0.00051895],
        [ 0.0188593 , -0.09666485, -0.22422633, -0.19876283, -0.14657046,
         -0.03839021,  0.23788255, -0.05810481,  0.17466428,  0.05607761,
          0.01819319, -0.11747983, -0.00355588,  0.05552798,  0.21469954,
          0.02979367],
        [-0.07972986, -0.09420505, -0.24408267, -0.18565055, -0.09185143,
         -0.0095228 ,  0.23303042, -0.00075869,  0.12077442,  0.04463999,
         -0.04770715, -0.06327784,  0.045044  ,  0.07344893,  0.24777579,
         -0.00886227],
        [ 0.00519919, -0.02781198,  0.05779533, -0.00535419,  0.09587848,
          0.11860488,  0.06507637, -0.06036546,  0.1939022 ,  0.15366904,
         -0.05568861,  0.06830101,  0.1332387 , -0.02283672,  0.0546639 ,
          0.12544396],
        [ 0.02893471, -0.08692809,  0.03823236,  0.07645208,  0.07632264,
          0.19249265,  0.04535065, -0.12270601,  0.16775191,  0.1155008 ,
         -0.08539387,  0.00234452, -0.04358566, -0.05436525, -0.02478817,
          0.2553227 ],
        [ 0.09466621, -0.03739176,  0.01791674,  0.12777455,  0.15859402,
          0.0646465 , -0.0835708 , -0.17655872,  0.10373822,  0.14420393,
          0.03712749, -0.00675274, -0.18880326, -0.17735889,  0.0100914 ,
          0.19743785],
        [ 0.15431932, -0.05348529,  0.05334238,  0.16724173,  0.18007884,
          0.03453815, -0.1851201 , -0.14165195,  0.00732323,  0.09458506,
          0.10690102,  0.02041647, -0.241884  , -0.2522594 , -0.00689178,
          0.1445705 ],
        [ 0.24555932,  0.09540935,  0.14411261,  0.34043302,  0.12379964,
         -0.1129966 , -0.19834936, -0.06646834,  0.12707064,  0.16486904,
          0.040797  ,  0.00081681, -0.19858868, -0.2392037 ,  0.05258687,
         -0.18024845],
        [ 0.08527663, -0.01553275, -0.04132372,  0.04184508, -0.04872827,
         -0.18967336, -0.07614696,  0.19419587, -0.04699756, -0.04536202,
          0.18612597, -0.03260128, -0.10163605, -0.15020102,  0.13656208,
         -0.07045483],
        [ 0.00902652,  0.00802411,  0.00349441,  0.0066363 ,  0.02812586,
         -0.2299042 ,  0.00387813,  0.24113983, -0.09704951, -0.10187069,
          0.13439021, -0.01243524, -0.08960348, -0.10325231,  0.15756624,
         -0.29002312]],

       [[-0.05148575,  0.06720759, -0.07206267, -0.06812521,  0.03614493,
          0.04767764, -0.09072717,  0.13137064, -0.16131891, -0.23117387,
          0.0621554 , -0.09076673,  0.1179775 ,  0.32040227, -0.05554256,
          0.09928288],
        [-0.09680876,  0.16121176, -0.13065313, -0.15421254,  0.07431257,
         -0.04082315, -0.02974331,  0.29206754, -0.15111419, -0.29595181,
          0.13282585, -0.06019135,  0.10809174,  0.37644085, -0.0440165 ,
          0.00685254],
        [-0.03421361,  0.13748931, -0.09908259, -0.22864179, -0.00654961,
         -0.0785063 , -0.02805491,  0.10213422, -0.13717202, -0.06699498,
          0.17477517,  0.0122198 ,  0.06943023,  0.44624913, -0.05628462,
          0.08216506],
        [ 0.01781805,  0.28241267, -0.05246943, -0.08493888, -0.02540539,
         -0.13000158, -0.03256825, -0.0878165 , -0.01878073,  0.26896141,
          0.13264408,  0.12267464,  0.0242984 ,  0.40396969, -0.12822519,
          0.15831019],
        [ 0.08107763,  0.12859033, -0.05616897, -0.03728299, -0.0312746 ,
         -0.08662511, -0.02553875, -0.0985618 ,  0.0650238 ,  0.03627828,
          0.12857739,  0.02939889, -0.03254754,  0.18302167, -0.0657735 ,
          0.12831271],
        [ 0.06255597, -0.00889141, -0.10634166, -0.09577743, -0.08439649,
         -0.12606134,  0.0083995 , -0.12306813, -0.03984412, -0.09111913,
          0.10249516, -0.05978723, -0.16387474,  0.17921828,  0.00099589,
          0.13509734],
        [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ]]]), 'final_states': array([[ 0.01704464,  0.01375266,  0.00569932,  0.01359386,  0.0599031 ,
        -0.5310615 ,  0.01106043,  0.63132237, -0.19392737, -0.22410165,
         0.24611573, -0.04944791, -0.21992734, -0.26848952,  0.28565425,
        -0.59903395,  0.00902652,  0.00802411,  0.00349441,  0.0066363 ,
         0.02812586, -0.2299042 ,  0.00387813,  0.24113983, -0.09704951,
        -0.10187069,  0.13439021, -0.01243524, -0.08960348, -0.10325231,
         0.15756624, -0.29002312],
       [ 0.12148278, -0.01844205, -0.3399909 , -0.23447282, -0.18236695,
        -0.18892043,  0.01784129, -0.27405052, -0.06386667, -0.22924761,
         0.16695596, -0.11379451, -0.25193653,  0.46508263,  0.00269928,
         0.34539569,  0.06255597, -0.00889141, -0.10634166, -0.09577743,
        -0.08439649, -0.12606134,  0.0083995 , -0.12306813, -0.03984412,
        -0.09111913,  0.10249516, -0.05978723, -0.16387474,  0.17921828,
         0.00099589,  0.13509734]])}

采用双向BiLSTM

outputs, last_states = tf.nn.bidirectional_dynamic_rnn(
                    cell_fw=cell,
                    cell_bw=cell,
                    inputs=X,
                    sequence_length=X_lengths,
                    dtype=tf.float64,
                    time_major=False
                )

运行结果如下

{'outputs': (array([[[ -8.41238987e-02,   2.02342451e-01,   2.60368436e-02,
          -6.91997149e-02,  -1.00815477e-02,  -1.84710375e-02,
           4.73017362e-02,  -1.46565564e-01,  -3.46681283e-02,
           1.13812954e-01,   1.12076022e-01,   9.64718047e-02,
           5.89444571e-02,   1.13485622e-02,  -5.93493561e-02,
           7.73265818e-02],
        [  1.90049974e-03,   3.13072057e-01,  -7.36657193e-02,
           5.85418739e-02,   1.73268038e-01,   6.05968861e-02,
           7.95394467e-02,  -4.45080027e-02,   5.93563046e-02,
           5.73828147e-02,   1.05410051e-01,   8.10913136e-02,
           7.32087617e-03,  -1.65634074e-02,  -3.33619727e-03,
          -3.27739570e-02],
        [  1.17158311e-01,   7.44036804e-02,  -3.29324582e-03,
           9.05722883e-02,   8.89695327e-02,   1.12105162e-01,
           1.00975061e-01,   5.15468347e-02,   5.14357339e-02,
           2.06787209e-02,  -2.83383422e-02,   2.11286105e-02,
           3.43771341e-02,  -7.92418082e-02,   8.53362033e-02,
          -8.16963366e-02],
        [ -4.25550595e-02,  -7.65800479e-04,   1.74860287e-01,
           9.02672757e-02,   4.04515710e-02,   1.11408791e-01,
           1.81697921e-01,   2.39430182e-02,  -8.38633094e-03,
           1.71085707e-01,   4.48776829e-02,   8.19355270e-02,
           1.65336573e-01,  -1.86586163e-01,   1.40200445e-01,
          -4.97211731e-02],
        [  9.68253056e-02,   6.06495878e-02,   2.08696567e-01,
           3.87971635e-02,   5.61433717e-02,   3.35399408e-02,
           1.11537949e-01,   4.59420685e-02,   1.41582280e-02,
           7.49530534e-03,   1.55239730e-02,  -4.05773611e-02,
           1.58862006e-01,  -2.09700998e-01,  -2.13514475e-03,
          -3.57567483e-03],
        [  1.32149638e-01,   9.30745932e-02,   2.46366699e-01,
           9.27931858e-02,   1.09161202e-01,   8.44090243e-02,
           6.41185248e-02,   6.26056890e-02,   7.41848483e-03,
          -5.86925305e-02,   3.52253794e-02,  -3.67522778e-02,
           9.78261676e-02,  -1.45995336e-01,  -6.36220468e-02,
           6.23390181e-02],
        [  2.32111735e-01,  -2.37044221e-02,   1.77017646e-01,
          -3.97338283e-04,   1.67123262e-03,   7.70140793e-02,
           6.53800462e-02,   1.22855548e-02,   2.31961441e-02,
          -1.26319736e-01,   6.41205262e-02,  -1.53632481e-01,
           8.28545616e-02,  -1.20705745e-01,  -4.10238845e-02,
           9.20385553e-02],
        [  2.66012734e-01,   1.05477379e-01,   4.21216468e-01,
           1.70095562e-01,  -4.77457867e-02,   1.68168287e-01,
           8.39350881e-02,  -3.55770945e-02,  -7.47878387e-02,
          -2.45388387e-01,   1.18436625e-01,  -7.06543538e-02,
           1.85329579e-01,  -1.84038439e-01,   7.32075603e-02,
           9.42068544e-02],
        [  7.95414388e-02,   1.16495551e-01,   1.97109709e-01,
           1.37920460e-02,  -2.43955191e-03,   1.06759844e-01,
           7.55164241e-02,  -1.15271431e-01,  -6.19764104e-02,
          -1.80050690e-01,   1.04730874e-01,  -9.92265522e-03,
           1.36772531e-01,  -9.17064895e-02,  -8.42246795e-02,
           1.01567247e-01],
        [  1.68130664e-01,   1.21676118e-01,   3.33891988e-02,
          -3.62197881e-02,  -2.31209753e-02,  -4.31261261e-02,
           3.24759979e-03,   2.58159477e-02,   7.60624235e-02,
          -1.77180028e-01,  -1.77196301e-02,  -8.36235778e-02,
           3.21666168e-03,  -6.89595460e-03,   2.80242017e-02,
           1.99551342e-02]],

       [[  1.56593441e-02,  -3.41630711e-02,   2.05353309e-02,
          -8.23938793e-02,  -9.61087545e-02,  -1.14604591e-01,
           2.42885439e-02,   3.98282281e-02,  -2.66879271e-02,
          -8.42482121e-02,  -2.03660879e-01,  -1.11103103e-01,
          -1.72252207e-01,   4.00229824e-02,   1.23120260e-01,
          -1.10854331e-01],
        [  5.32126508e-02,  -1.04440576e-01,  -9.68604478e-02,
           1.63228590e-02,   1.01150650e-01,  -1.02453878e-02,
          -1.33852475e-03,   1.64967070e-01,  -8.67544927e-03,
          -7.64610844e-02,  -3.54521460e-01,  -3.66237321e-02,
          -2.45614683e-01,   4.51928160e-02,   6.81918975e-02,
          -1.47578191e-01],
        [ -8.66380079e-02,  -1.01205470e-01,   1.25281612e-01,
           8.15776305e-02,   1.32941869e-01,   7.84578282e-02,
          -2.97582357e-02,   1.60467739e-01,  -1.51084771e-01,
          -6.00740157e-02,  -2.50723724e-01,   1.69384964e-01,
           2.09864500e-02,  -1.09103700e-01,   1.35223647e-01,
          -5.26320391e-02],
        [ -3.79775701e-02,  -9.84794905e-03,   1.04281425e-01,
           2.18962102e-01,   9.49885755e-02,   9.30334434e-02,
           1.20880917e-02,   4.92450967e-02,  -1.39313910e-01,
          -1.67242457e-01,  -8.92664684e-02,   5.76337605e-02,
           1.06973235e-01,  -1.88442473e-01,   1.16279010e-01,
          -6.63430587e-02],
        [  1.13913179e-01,  -7.01235476e-02,   8.72963709e-02,
           1.04267413e-01,  -5.76935594e-02,   3.81271792e-02,
           4.09102123e-03,   7.76470917e-02,  -6.38018038e-03,
          -2.10545446e-01,  -3.69164679e-02,  -1.20748943e-01,
           1.32309592e-01,  -1.45925061e-01,   1.47558076e-01,
          -7.00836697e-02],
        [  4.32247065e-02,  -6.83892073e-02,   7.56469023e-03,
          -6.24281552e-02,   2.80944091e-02,   1.22108990e-02,
          -4.95889882e-04,   1.40971813e-01,   1.07006212e-01,
          -1.11817313e-01,  -2.12919101e-01,  -1.15743608e-01,
          -6.11031704e-02,  -2.48300200e-02,   7.24766670e-03,
          -7.83430451e-02],
        [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00],
        [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00],
        [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00],
        [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
           0.00000000e+00]]]), array([[[ 0.02356511,  0.24838725,  0.17932081,  0.05942439,  0.15138192,
          0.12581299,  0.07806339, -0.04999007,  0.02852398, -0.01691062,
          0.12604886,  0.03830202,  0.23902503, -0.10546712,  0.06509197,
          0.08460939],
        [ 0.14112674,  0.04603104,  0.16150333,  0.12562632,  0.16140194,
          0.13926239,  0.06831767,  0.13488513,  0.06390905, -0.09083964,
          0.03295962, -0.01229307,  0.21374753, -0.10903434,  0.11961082,
         -0.03469615],
        [ 0.25343172, -0.14971647,  0.2810062 ,  0.03923167, -0.01272575,
          0.07319106,  0.06496102,  0.04369333, -0.00583105, -0.14490683,
          0.04070688, -0.11250917,  0.24891372, -0.18076371,  0.10464212,
          0.05199773],
        [ 0.17880187, -0.04316056,  0.30915784, -0.03128962,  0.0136552 ,
          0.00114268,  0.11112965, -0.02416813, -0.01660347, -0.18926507,
          0.12892229, -0.09453105,  0.18057936, -0.17712924,  0.04842806,
          0.13493346],
        [ 0.2715969 ,  0.02844215,  0.17219262, -0.04691747,  0.07650773,
         -0.01564628,  0.05903233,  0.00140291,  0.03676231, -0.29671783,
          0.05205306, -0.10022505, -0.00608728, -0.11042412, -0.0303516 ,
          0.05871938],
        [ 0.17246075, -0.0050864 ,  0.13435615,  0.044755  ,  0.05607737,
          0.07413086,  0.09217622, -0.03587575,  0.02821449, -0.23358441,
          0.08262469, -0.05006317,  0.02842524, -0.0650939 ,  0.0275634 ,
          0.07439376],
        [ 0.18414017, -0.01139999,  0.07430255, -0.01925044, -0.02509538,
          0.04357774,  0.1246182 , -0.0357719 ,  0.02225083, -0.28293001,
          0.08652195, -0.08315333,  0.03829619, -0.06317638,  0.06140839,
          0.00854522],
        [ 0.11787692,  0.16528217,  0.0742104 ,  0.09235664,  0.10792024,
          0.05196128,  0.09740701, -0.02470125, -0.03088181, -0.220567  ,
          0.03927874,  0.06127843,  0.04881179, -0.07778106,  0.068567  ,
         -0.02297281],
        [-0.01108608,  0.07571883, -0.16907375, -0.08356024,  0.12267888,
         -0.11135617,  0.01642111,  0.03403353,  0.08506832, -0.06965933,
         -0.1217141 , -0.00500899, -0.24758602,  0.11082664, -0.07299259,
         -0.0581366 ],
        [ 0.08893831,  0.04202115, -0.06253906, -0.04473577,  0.0229234 ,
         -0.12935649, -0.0163277 ,  0.13139815,  0.13483017, -0.08191296,
         -0.08125861, -0.07044632, -0.10374899,  0.03506476,  0.06903886,
         -0.07480185]],

       [[ 0.01747829, -0.08909235,  0.02053894,  0.01685996,  0.0389529 ,
         -0.00709683, -0.00462855,  0.08696204, -0.03342217, -0.21788043,
         -0.30605847, -0.1349446 , -0.11533038, -0.09036049,  0.15674459,
         -0.126863  ],
        [ 0.00525433, -0.10835235,  0.0027264 ,  0.08466345,  0.14112028,
          0.09043956, -0.04722811,  0.15436025, -0.00785834, -0.08981446,
         -0.24305637, -0.00319932,  0.05808224, -0.12561799,  0.04564026,
         -0.02585145],
        [-0.06045172, -0.09233718,  0.22095997,  0.04789267,  0.03415064,
          0.0262382 ,  0.01256547,  0.10090689, -0.05170264, -0.13021805,
         -0.0060994 ,  0.07757477,  0.06438366, -0.10231362,  0.26221725,
          0.02315334],
        [ 0.15577114, -0.07492413, -0.01727874, -0.07142947, -0.06811142,
         -0.1202229 ,  0.05028568,  0.0590934 ,  0.16453847, -0.23882856,
          0.01029733, -0.15489522,  0.003675  ,  0.02758079,  0.11096516,
         -0.10430096],
        [ 0.10145966, -0.10858013, -0.0650702 , -0.18066065, -0.00131774,
         -0.08608264,  0.01451162,  0.09330788,  0.1964142 , -0.11445147,
         -0.08806287, -0.15021895, -0.07837124,  0.07806647,  0.01695358,
         -0.09840479],
        [-0.01788662, -0.03722207, -0.0847649 , -0.10571478,  0.07626006,
         -0.01792801,  0.01216043,  0.07324065,  0.10951483,  0.0420184 ,
         -0.19723146, -0.02916167, -0.19582316,  0.09968804, -0.16278298,
         -0.05184987],
        [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ]]])), 'final_states': (LSTMStateTuple(c=array([[ 0.35089277,  0.28169591,  0.10655786, -0.06038076, -0.04643542,
        -0.08835599,  0.01115685,  0.05617012,  0.13590812, -0.45686304,
        -0.03412385, -0.25160794,  0.0052527 , -0.01301877,  0.07284551,
         0.0505886 ],
       [ 0.09893459, -0.17312469,  0.01775045, -0.16009508,  0.06735036,
         0.02437652, -0.00088507,  0.33421125,  0.29787557, -0.2440617 ,
        -0.33060605, -0.29383554, -0.0913775 , -0.04287903,  0.01267763,
        -0.18018947]]), h=array([[ 0.16813066,  0.12167612,  0.0333892 , -0.03621979, -0.02312098,
        -0.04312613,  0.0032476 ,  0.02581595,  0.07606242, -0.17718003,
        -0.01771963, -0.08362358,  0.00321666, -0.00689595,  0.0280242 ,
         0.01995513],
       [ 0.04322471, -0.06838921,  0.00756469, -0.06242816,  0.02809441,
         0.0122109 , -0.00049589,  0.14097181,  0.10700621, -0.11181731,
        -0.2129191 , -0.11574361, -0.06110317, -0.02483002,  0.00724767,
        -0.07834305]])), LSTMStateTuple(c=array([[ 0.05235715,  0.4449884 ,  0.29395379,  0.10136493,  0.23366848,
         0.23763545,  0.11706335, -0.09597744,  0.04718805, -0.0265409 ,
         0.41051355,  0.0609024 ,  0.50751285, -0.24001079,  0.1205789 ,
         0.16522967],
       [ 0.02442768, -0.26505515,  0.06781086,  0.03441641,  0.08381474,
        -0.01540371, -0.01586349,  0.36816326, -0.06578872, -0.48389251,
        -0.49614392, -0.30551384, -0.17350058, -0.11728567,  0.37299283,
        -0.30650555]]), h=array([[ 0.02356511,  0.24838725,  0.17932081,  0.05942439,  0.15138192,
         0.12581299,  0.07806339, -0.04999007,  0.02852398, -0.01691062,
         0.12604886,  0.03830202,  0.23902503, -0.10546712,  0.06509197,
         0.08460939],
       [ 0.01747829, -0.08909235,  0.02053894,  0.01685996,  0.0389529 ,
        -0.00709683, -0.00462855,  0.08696204, -0.03342217, -0.21788043,
        -0.30605847, -0.1349446 , -0.11533038, -0.09036049,  0.15674459,
        -0.126863  ]])))}

可以看到,对于forward来说,final_states 分别为 (c9,h9)(c5,h5) ( c 9 , h 9 ) 和 ( c 5 , h 5 ) , 但是对于backward来说, finalstates=(c0,h0)(c0,h0) f i n a l s t a t e s = ( c 0 , h 0 ) 和 ( c 0 , h 0 ) .

Ref

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值