参考资料:https://www.xszz.org/faq-2/question-2018101955896.html
原代码
def decoder_lstm_based(h_decoder_in, encoder_final_state):
    """LSTM-based decoder: runs a stacked LSTM over `h_decoder_in`, seeded
    with the encoder's final state, then projects every time step to
    vocabulary logits.

    Returns:
        (s_decoded, logits) -- softmax probabilities and raw logits, both
        reshaped to (batch, sentence_length + 1, vocabulary_size) per the
        inline shape notes.
    """
    # BUG (this is the error the post discusses): the SAME cell object is
    # reused for every layer of the stack below, but layer 1+ receive
    # different input sizes than layer 0, so their kernel matrices cannot
    # share variables -- TF raises a shape/variable-reuse error.
    cell_1 = tf.contrib.rnn.BasicLSTMCell(config.n_hidden, activation=tf.nn.leaky_relu)
    lstm_cell = tf.contrib.rnn.MultiRNNCell([cell_1 for _ in range(config.num_layers_encoder)], state_is_tuple=True)
    init_state = ( encoder_final_state)
    outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, h_decoder_in, initial_state=init_state, time_major=False) #(batch_size,10,256)
    outputs_reshaped = tf.reshape(outputs,[-1,config.n_hidden]) #(batch_size*10,256)
    fc_1 = tf.layers.dense(outputs_reshaped,config.vocabulary_size)#(batch_size*10,21)
    logits = tf.reshape(fc_1,[-1,config.sentence_length+1,config.vocabulary_size]) #(batch_size,10,21)
    s_decoded = tf.nn.softmax(logits=logits, axis=2)
    return s_decoded,logits
其中第二行和第三行出现错误,按照“https://www.xszz.org/faq-2/question-2018101955896.html”的指点:
You should not reuse the same cell for the first and deeper layers, because their inputs are different, hence kernel matrices are different. Try this:
# Extra function is for readability. No problem to inline it.
def make_cell(lstm_size):
    """Build a FRESH BasicLSTMCell (quoted fix: one new cell per layer)."""
    return tf.nn.rnn_cell.BasicLSTMCell(lstm_size, state_is_tuple=True)

# NOTE(review): `rnn_cell`, `num_units` and `num_layers` are free names in
# this quoted fragment -- they belong to the linked answer's own context and
# are not defined anywhere in this file.
network = rnn_cell.MultiRNNCell([make_cell(num_units) for _ in range(num_layers)], state_is_tuple=True)
也就是说原来的代码是在第一层和第二层中使用了同一个cell,按照这个想法,我修改代码为:
def make_cell(lstm_size, activation=tf.nn.leaky_relu):
    """Create one fresh BasicLSTMCell for a single decoder layer.

    A new cell object must be constructed per layer: deeper layers see
    different input sizes, hence need different kernel matrices, so a
    single cell instance cannot be shared across the stack.

    Args:
        lstm_size: number of hidden units for the cell.
        activation: cell activation; defaults to leaky_relu.

    Fix: the first rewrite silently dropped the `activation=tf.nn.leaky_relu`
    used by the original code (and kept by the alternative version), which
    reverted the cells to the default tanh. Restoring it as the default
    keeps all versions behaviorally consistent; callers may still override.
    """
    return tf.contrib.rnn.BasicLSTMCell(lstm_size, activation=activation)
def decoder_lstm_based(h_decoder_in, encoder_final_state):
    """Decode with a stacked LSTM initialised from the encoder's final state.

    Each layer gets its own freshly built cell (via `make_cell`), avoiding
    the variable-sharing error of the original single-cell version.

    Returns:
        (probs, logits) -- per-step softmax probabilities and raw logits,
        shaped (batch, sentence_length + 1, vocabulary_size).
    """
    # One distinct cell object per layer.
    stacked = tf.contrib.rnn.MultiRNNCell(
        [make_cell(config.n_hidden) for _ in range(config.num_layers_encoder)],
        state_is_tuple=True)
    rnn_out, _ = tf.nn.dynamic_rnn(
        stacked,
        h_decoder_in,
        initial_state=encoder_final_state,
        time_major=False)                                       # (batch_size, 10, 256)
    flat = tf.reshape(rnn_out, [-1, config.n_hidden])           # (batch_size*10, 256)
    projected = tf.layers.dense(flat, config.vocabulary_size)   # (batch_size*10, 21)
    logits = tf.reshape(
        projected,
        [-1, config.sentence_length + 1, config.vocabulary_size])  # (batch_size, 10, 21)
    probs = tf.nn.softmax(logits=logits, axis=2)
    return probs, logits
成功通过!
当然,也可以选择另一种方式:
def decoder_lstm_based(h_decoder_in, encoder_final_state):
    """Alternative decoder: one distinct LSTM cell per entry of an explicit
    per-layer size list, instead of reusing a single cell object.

    Behaviour matches the other fixed version: stacked LSTM seeded by the
    encoder's final state, then a dense projection to vocabulary logits.

    Returns:
        (probs, logits), each shaped
        (batch, sentence_length + 1, vocabulary_size).
    """
    layer_sizes = [config.n_hidden, config.n_hidden]
    cells = [
        tf.contrib.rnn.BasicLSTMCell(num_units=size, activation=tf.nn.leaky_relu)
        for size in layer_sizes
    ]
    decoder_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)
    step_outputs, _ = tf.nn.dynamic_rnn(
        decoder_cell,
        h_decoder_in,
        initial_state=encoder_final_state,
        time_major=False)                                           # (batch_size, 10, 256)
    flat_steps = tf.reshape(step_outputs, [-1, config.n_hidden])    # (batch_size*10, 256)
    dense_out = tf.layers.dense(flat_steps, config.vocabulary_size)  # (batch_size*10, 21)
    logits = tf.reshape(
        dense_out,
        [-1, config.sentence_length + 1, config.vocabulary_size])   # (batch_size, 10, 21)
    probs = tf.nn.softmax(logits=logits, axis=2)
    return probs, logits