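When building a multi-layer LSTM recurrent neural network in TensorFlow with BasicLSTMCell and MultiRNNCell, note that many example programs found online use MultiRNNCell incorrectly and mislead their readers.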
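- Replicating a single BasicLSTMCell directly, as below, is wrong: the list holds the same cell object several times, so all layers share one set of weights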
    import tensorflow as tf

    # Wrong: [basic_cell]*layer_num repeats ONE cell object layer_num times
    basic_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_unit)
    multi_cell = tf.nn.rnn_cell.MultiRNNCell([basic_cell]*layer_num)
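Why the layers end up sharing weights can be seen without TensorFlow at all. The sketch below uses a hypothetical stand-in class, DummyCell (not part of TensorFlow), that only holds a list of numbers in place of real weights. It assumes nothing beyond plain Python: multiplying a list repeats references to one object, while a list comprehension calls the constructor once per layer.

```python
class DummyCell:
    """Hypothetical stand-in for an RNN cell; it only holds a list of weights."""

    def __init__(self, num_units):
        self.weights = [0.0] * num_units


layer_num = 3

# List multiplication repeats the reference: one object appears three times.
shared_cells = [DummyCell(4)] * layer_num

# A list comprehension calls the constructor once per layer.
separate_cells = [DummyCell(4) for _ in range(layer_num)]

# Change the "weights" of the first layer in each list.
shared_cells[0].weights[0] = 1.0
separate_cells[0].weights[0] = 1.0

# The change shows up in every layer of the shared list,
# but only in layer 0 of the list built by the comprehension.
shared_first_weights = [cell.weights[0] for cell in shared_cells]
separate_first_weights = [cell.weights[0] for cell in separate_cells]
print(shared_first_weights)    # [1.0, 1.0, 1.0]
print(separate_first_weights)  # [1.0, 0.0, 0.0]
```

The same aliasing applies to any Python object, so a real cell repeated this way is one set of variables used by every layer.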
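- The source code of newer TensorFlow versions shows that passing the same cell object more than once, as the [basic_cell]*layer_num form does, produces a warning: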
    if len(set([id(cell) for cell in cells])) < len(cells):
      logging.log_first_n(logging.WARN,
                          "At least two cells provided to MultiRNNCell "
                          "are the same object and will share weights.", 1)
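The identity check in that warning can be tried in isolation. The following is a minimal sketch in plain Python, not TensorFlow's own code: has_shared_cells and DummyCell are hypothetical names introduced here for illustration, and the function applies the same idea of comparing the number of distinct object ids with the length of the list.

```python
def has_shared_cells(cells):
    """Return True if at least two entries of `cells` are the same object."""
    return len({id(cell) for cell in cells}) < len(cells)


class DummyCell:
    """Hypothetical placeholder for a cell object; it carries no state."""


# One object repeated twice: the check fires.
repeated = [DummyCell()] * 2

# Two separately constructed objects: the check stays quiet.
independent = [DummyCell() for _ in range(2)]

repeated_is_shared = has_shared_cells(repeated)
independent_is_shared = has_shared_cells(independent)
print(repeated_is_shared)     # True
print(independent_is_shared)  # False
```

Running a check like this on a list of cells before handing it to MultiRNNCell is one way to catch accidental sharing early.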
- The officially recommended form uses a list comprehension, so that each layer gets its own cell object:
    # BasicLSTMCell and MultiRNNCell here are the classes in tf.nn.rnn_cell
    num_units = [128, 64]
    cells = [BasicLSTMCell(num_units=n) for n in num_units]
    stacked_rnn_cell = MultiRNNCell(cells)