- 模型分步训练方法综述
对于较为复杂(如CNN与LSTM级联),或者网络层数过深的深度学习模型,在训练过程中,如果采用直接训练的方式,会导致训练难度极大,甚至模型收敛性降低。此时需要采用固定部分参数,训练部分参数的方式来分开训练整个庞大模型,在完成模型的分布训练之后,模型的各个部分均能较好的收敛,并保存得到优良的参数。随后在测试过程中,分别恢复保存好的这两部分参数,完成模型的测试工作。下面以CNN与LSTM级联模型为例,介绍分步训练方法。
- 要点分类
第一,在构建模型时,需要预先为模型中的各个子模块设定变量范围scope,便于变量管理;以下是设置scope的几种方式。
with tf.variable_scope('inference', reuse = tf.AUTO_REUSE):
xxxxx
conv6 = slim.conv2d(conv5, 128, [3,3], stride=2, padding='VALID', scope='conv2d_1a_3x3')
第二,因为后续涉及到模型参数的分开保存,需要首先设法获取两类参数各自的集合;
var_list = tf.trainable_variables() #获取所有变量
#根据变量名相关特征获取参数集合
lstm_variables = [v for v in var_list if v.name.split('/')[1].split('_')[0] == 'lstm']
cnn_variables = [v for v in var_list if v.name.split('/')[1].split('_')[0] != 'lstm']
第三,设置不同的saver与相应检查点路径,用于保存和恢复对应类别的参数
saver = tf.train.Saver(cnn_variables) #saver用于保存部分参数:cnn_variable
saver2 = tf.train.Saver(lstm_variables) #saver2用于保存部分参数lstm_variable
saver1 = tf.train.Saver() #saver1默认保存神经网络会话中的所有参数
ckpt = tf.train.get_checkpoint_state(logs_dir_lstm) #此处的log_dir_lstm对应lstm参数保存路径,如果有预训练参数,可以在预训练的基础上fine tune
if ckpt and ckpt.model_checkpoint_path:
saver2.restore(sess, ckpt.model_checkpoint_path) #此处调用saver2
print("Model restored...")
第四,训练过程中保存相关参数
if index % 1000 == 0 or index == MAX_STEP - 1:
checkpoint_path = os.path.join(logs_dir_cnn, 'model.ckpt') #设定参数保存路径
saver.save(sess, checkpoint_path, global_step= index) #指定相应saver,保存对应参数
-指定训练部分参数
参数的训练与模型参数的保存是两个不同的问题,在训练时同样可以通过指定参数范围来控制需要训练的变量。下边的代码中,通过设定FLAGS.tune_scope来指定训练变量范围。
def _get_training(logits, labels):
"""Set up training ops"""
with tf.name_scope("train"):
if FLAGS.tune_scope:
scope = FLAGS.tune_scope
else:
scope = r"inference/conv"
# scope = r"inference/lstm_layer"
# rnn_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
cnn_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
# rnn_vars = var_filter(tf.trainable_variables(), last_layers = range(2))
global_step = tf.Variable(0)
trainable_var_collection.append(cnn_vars)
loss = losses(logits, labels)
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', accuracy)
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(extra_update_ops):
learning_rate = tf.train.exponential_decay(
FLAGS.learning_rate,
global_step,
FLAGS.decay_steps,
FLAGS.decay_rate,
staircase=FLAGS.decay_staircase,
name='lr')
optimizer = tf.train.AdamOptimizer(
learning_rate=learning_rate,
beta1=FLAGS.momentum)
train_op = tf.contrib.layers.optimize_loss(
loss=loss,
global_step=tf.train.get_global_step(),
learning_rate=learning_rate,
optimizer=optimizer,
variables=trainable_var_collection)
tf.summary.scalar('learning_rate', learning_rate)
tf.summary.scalar('loss', loss)
return train_op, loss