When fine-tuning a model with TensorFlow, if you want to train only specific ops, you need to collect the corresponding variables by their scope names and pass them to the var_list argument of tf.train.AdamOptimizer.
Partial code:
def _get_variables_to_train(trainable_scopes=None):
    """Returns a list of variables to train.

    Args:
        trainable_scopes: comma-separated scope names, or None to train everything.

    Returns:
        A list of variables to train by the optimizer.
    """
    if trainable_scopes is None:
        return tf.trainable_variables()
    scopes = [scope.strip() for scope in trainable_scopes.split(',')]
    variables_to_train = []
    for scope in scopes:
        # Collect the trainable variables whose names fall under this scope.
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)
    return variables_to_train
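The scope matching above can be sketched in plain Python: tf.get_collection with a scope argument keeps the variables whose names begin with that scope prefix. The variable names below are hypothetical, chosen to resemble the HCnet scopes used later:

```python
def filter_by_scopes(var_names, trainable_scopes):
    """Mimic the scope-prefix matching that tf.get_collection performs."""
    if trainable_scopes is None:
        return list(var_names)
    scopes = [scope.strip() for scope in trainable_scopes.split(',')]
    return [name for name in var_names
            if any(name.startswith(scope) for scope in scopes)]

# Hypothetical variable names from an HCnet-style graph.
all_vars = [
    'HCnet/Bottle_neck4/weights',
    'HCnet/Bottle_neck5/weights',
    'HCnet/Conv7/biases',
]
print(filter_by_scopes(all_vars, 'HCnet/Bottle_neck5,HCnet/Conv7'))
# -> ['HCnet/Bottle_neck5/weights', 'HCnet/Conv7/biases']
```

Note that prefix matching means a scope like HCnet/Bottle_neck5 also matches HCnet/Bottle_neck5_1, which is why both appear explicitly in the scope list below.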
output_vars = _get_variables_to_train(Config.trainable_scopes)

# Run the UPDATE_OPS (e.g. batch-norm moving averages) before each train step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(
        loss, global_step=global_step, var_list=output_vars)

trainable_scopes = 'HCnet/Bottle_neck5,HCnet/Bottle_neck5_1,HCnet/Bottle_neck6_2,HCnet/Conv7'
trainable_scopes names the ops I want to train. The function _get_variables_to_train collects the corresponding variables, and the result output_vars is passed as the var_list argument of tf.train.AdamOptimizer, so only those variables are updated during training. The model-preloading (checkpoint restore) step is omitted above.
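The effect of var_list can be illustrated without TensorFlow: a minimal pure-Python sketch of one gradient step that updates only the selected variables and leaves the rest frozen. The names, values, and learning rate are hypothetical:

```python
def sgd_step(params, grads, var_list, lr=0.5):
    """Update only the variables named in var_list; everything else
    stays frozen, as with the var_list argument of minimize()."""
    return {name: (value - lr * grads[name]) if name in var_list else value
            for name, value in params.items()}

params = {'HCnet/Conv1/w': 1.0, 'HCnet/Conv7/w': 3.0}
grads  = {'HCnet/Conv1/w': 2.0, 'HCnet/Conv7/w': 2.0}

updated = sgd_step(params, grads, var_list=['HCnet/Conv7/w'])
# Only 'HCnet/Conv7/w' moves (3.0 -> 2.0); 'HCnet/Conv1/w' stays at 1.0.
```

This is exactly why the frozen layers keep their pretrained weights during fine-tuning: they still receive gradients in the graph, but the optimizer never applies updates to them.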