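Fine-tuning in TensorFlow
In TensorFlow, the current training task can be fine-tuned by loading only part of the parameters of an already trained model.
1. Restoring a subset of parameters with the tf.train.Saver class
When creating the saver object, pass the parameters to restore as a dictionary that maps each parameter's name in the saved model to the corresponding variable in the current graph: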
import tensorflow as tf  # TensorFlow 1.x API

with tf.variable_scope('conv_layer1-other'):
    conv1_weights = tf.get_variable("weight", [CONV1_SIZE, CONV1_SIZE, NUM_CHANNELS, CONV1_DEEP], initializer=tf.truncated_normal_initializer(stddev=0.1))
    ......
with tf.variable_scope('conv_layer2-other'):
    conv2_weights = tf.get_variable("weight", [CONV2_SIZE, CONV2_SIZE, CONV1_DEEP, CONV2_DEEP], initializer=tf.truncated_normal_initializer(stddev=0.1))
    ......
# Keys are the variable names stored in the checkpoint; values are the variables in the current graph.
variables_to_restore = {'conv_layer1/weight': conv1_weights, 'conv_layer2/weight': conv2_weights}
saver = tf.train.Saver(variables_to_restore)
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    # Initialize all variables first, then overwrite the mapped ones with the checkpoint values.
    sess.run(init_op)
    saver.restore(sess, model_name)
    ......
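The program above loads the values of the saved model's parameters named 'conv_layer1/weight' and 'conv_layer2/weight' into the current network's variables conv1_weights and conv2_weights (whose names in the current graph are 'conv_layer1-other/weight' and 'conv_layer2-other/weight').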
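Writing the dictionary by hand becomes tedious when a network has many layers. The following is a minimal sketch of how the name mapping could be generated, assuming the current scopes differ from the checkpoint scopes only by the '-other' suffix used in the example above. The helpers checkpoint_name and build_restore_map are hypothetical names, not part of TensorFlow, and plain strings stand in for the variable objects so that the sketch runs without TensorFlow installed.

```python
def checkpoint_name(current_name, suffix="-other"):
    """Map a current-graph variable name to its assumed checkpoint name
    by stripping `suffix` from the end of every scope component."""
    parts = current_name.split("/")
    scopes, leaf = parts[:-1], parts[-1]
    stripped = [s[:-len(suffix)] if suffix and s.endswith(suffix) else s
                for s in scopes]
    return "/".join(stripped + [leaf])


def build_restore_map(named_variables, suffix="-other"):
    """named_variables: {current_name: variable}.
    Returns {checkpoint_name: variable}, the dictionary form accepted by
    tf.train.Saver as var_list."""
    return {checkpoint_name(name, suffix): var
            for name, var in named_variables.items()}


# Strings stand in for the tf.Variable objects (assumption for illustration).
current = {
    "conv_layer1-other/weight": "conv1_weights",
    "conv_layer2-other/weight": "conv2_weights",
}
restore_map = build_restore_map(current)
print(restore_map)
# {'conv_layer1/weight': 'conv1_weights', 'conv_layer2/weight': 'conv2_weights'}
```

In a real TF 1.x graph, the input dictionary would presumably be built from the variables themselves, for example keyed by var.op.name, and the result passed to tf.train.Saver.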
2. The method in the slim module
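import tensorflow.contrib.slim as slim  # TensorFlow 1.x only

# Variables listed in `exclude` are skipped; all other model variables are restored.
exclude = ['conv_layer1/weight', 'conv_layer2/weight']
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
# Build a function that loads the selected variables from the checkpoint.
init_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore, ignore_missing_vars=ignore_missing_vars)
# Start training.
slim.learning.train(train_op, log_dir, init_fn=init_fn)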
Here, when ignore_missing_vars is True, variables that do not exist in the checkpoint are ignored instead of causing the restore to fail.
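The combined effect of exclude and ignore_missing_vars can be sketched in plain Python. This is only an illustration of the selection logic, not slim's implementation: it approximates the scope filter as a simple name-prefix match, and the function select_variables_to_restore and the names 'fc1/weight' and 'fc2/weight' are hypothetical.

```python
def select_variables_to_restore(model_vars, exclude, checkpoint_vars,
                                ignore_missing_vars=False):
    """Return the names of the variables that would be loaded from the checkpoint."""
    # Step 1: drop every variable whose name starts with an excluded prefix.
    candidates = [name for name in model_vars
                  if not any(name.startswith(prefix) for prefix in exclude)]
    # Step 2: find candidates that the checkpoint does not contain.
    missing = [name for name in candidates if name not in checkpoint_vars]
    if missing and not ignore_missing_vars:
        # Without the flag, a missing variable makes the restore fail.
        raise KeyError("not found in checkpoint: %s" % ", ".join(missing))
    # With the flag, missing variables are silently skipped.
    return [name for name in candidates if name in checkpoint_vars]


model_vars = ["conv_layer1/weight", "conv_layer2/weight",
              "fc1/weight", "fc2/weight"]
checkpoint_vars = {"conv_layer1/weight", "conv_layer2/weight", "fc1/weight"}
exclude = ["conv_layer1/weight", "conv_layer2/weight"]

restored = select_variables_to_restore(
    model_vars, exclude, checkpoint_vars, ignore_missing_vars=True)
print(restored)  # ['fc1/weight']
```

In this sketch 'fc2/weight' exists only in the current model, so it is skipped when the flag is True and would keep its initial value; with the flag left at False the same call raises an error.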