tf.contrib.rnn.MultiRNNCell
Aliases:
- Class tf.contrib.rnn.MultiRNNCell
- Class tf.nn.rnn_cell.MultiRNNCell
由多个简单的cells组成的RNN cell。用于构建多层循环神经网络。
__init__( cells, state_is_tuple=True ) |
参数:
- cells:RNNCells的list。
- state_is_tuple:如果为True,接受和返回的states是n-tuples,其中n=len(cells)。如果为False,states是concatenated沿着列轴.后者即将弃用。
代码实例:
import tensorflow as tf
batch_size=10
depth=128
inputs=tf.Variable(tf.random_normal([batch_size,depth]))
previous_state0=(tf.random_normal([batch_size,100]),tf.random_normal([batch_size,100]))
previous_state1=(tf.random_normal([batch_size,200]),tf.random_normal([batch_size,200]))
previous_state2=(tf.random_normal([batch_size,300]),tf.random_normal([batch_size,300]))
num_units=[100,200,300]
print(inputs)
cells=[tf.nn.rnn_cell.BasicLSTMCell(num_unit) for num_unit in num_units]
mul_cells=tf.nn.rnn_cell.MultiRNNCell(cells)
outputs,states=mul_cells(inputs,(previous_state0,previous_state1,previous_state2))
print(outputs.shape) #(10, 300)
print(states[0]) #第一层LSTM
print(states[1]) #第二层LSTM
print(states[2]) ##第三层LSTM
print(states[0].h.shape) #第一层LSTM的h状态,(10, 100)
print(states[0].c.shape) #第一层LSTM的c状态,(10, 100)
print(states[1].h.shape) #第二层LSTM的h状态,(10, 200)