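When applying a classic network such as VGG16 (used as the running example below) to your own task, you can decide, depending on your needs, whether to restore the model parameters from the weights that VGG16 was pre-trained with on ImageNet:
- Do not use the pre-trained parameters and train all parameters yourself. With enough data, this lets the model reach its full capacity and achieve high performance.
- Restore every parameter except those of fc8 from the pre-trained weights and train only fc8. This treats VGG16 as a feature extractor: the features produced by fc7 feed a softmax classifier. Training is easy, needs little data, and is fast, but the resulting model often does not perform very well.
- Restore some layers and train some layers. Typically the shallow layers are frozen and the deeper layers are trained, for example freezing conv1-conv5 and training fc6, fc7, and fc8.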
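The three options differ only in which layers are restored from the checkpoint and which are trained. A minimal, framework-free sketch of that bookkeeping follows. The layer list and the `plan` helper are hypothetical names used for illustration and are not part of TensorFlow or TF-slim; the sketch also assumes that, in the partial case, the restored layers are exactly the frozen ones.

```python
# Hypothetical, simplified VGG16 layer groups (illustration only).
LAYERS = ["conv1", "conv2", "conv3", "conv4", "conv5", "fc6", "fc7", "fc8"]


def plan(strategy, layers=LAYERS, first_trainable="fc6"):
    """Return (layers_to_restore, layers_to_train) for a strategy.

    strategy: "scratch", "feature_extractor", or "partial".
    first_trainable: for "partial", the first layer that is trained.
    """
    if strategy == "scratch":
        # No pre-trained weights; every layer is trained.
        return [], list(layers)
    if strategy == "feature_extractor":
        # Restore everything except the last layer and train only that layer.
        return list(layers[:-1]), [layers[-1]]
    if strategy == "partial":
        # Assumption: frozen shallow layers are restored, the rest is trained.
        split = layers.index(first_trainable)
        return list(layers[:split]), list(layers[split:])
    raise ValueError("unknown strategy: %r" % (strategy,))


for name in ("scratch", "feature_extractor", "partial"):
    restore, train = plan(name)
    print(name, "| restore:", restore, "| train:", train)
```

In practice the deeper layers in the partial case could also be restored and then fine-tuned; the split point is a design choice that depends on how much data is available.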
I. Restoring parameters from a pre-trained checkpoint
(1) Use tf.train.Saver to restore some or all of the model's parameters
    import tensorflow as tf

    # Create some variables.
    v1 = tf.Variable(..., name="v1")
    v2 = tf.Variable(..., name="v2")
    ...
    # Add ops to restore all the variables.
    restorer = tf.train.Saver()

    # Or, alternatively, add ops to restore only some variables.
    restorer = tf.train.Saver([v1, v2])

    # Later, launch the model, use the saver to restore variables from disk, and
    # do some work with the model.
    with tf.Session() as sess:
        # Restore variables from disk.
        restorer.restore(sess, "/tmp/model.ckpt")
        print("Model restored.")
        # Do some work with the model
        ...
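(2) Restore variables from a checkpoint by variable name
TF-slim provides a set of helper functions for selecting a subset of the variables in the current graph.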
    import tensorflow as tf

    slim = tf.contrib.slim

    # Create some variables.
    v1 = slim.variable(name="v1", ...)
    v2 = slim.variable(name="nested/v2", ...)
    ...

    # Get list of variables to restore (which contains only 'v2'). These are all
    # equivalent methods:
    variables_to_restore = slim.get_variables_by_name("v2")
    # or
    variables_to_restore = slim.get_variables_by_suffix("2")
    # or
    variables_to_restore = slim.get_variables(scope="nested")
    # or
    variables_to_restore = slim.get_variables_to_restore(include=["nested"])
    # or
    variables_to_restore = slim.get_variables_to_restore(exclude=["v1"])

    # Create the saver which will be used to restore the variables.
    restorer = tf.train.Saver(variables_to_restore)

    with tf.Session() as sess:
        # Restore variables from disk.
        restorer.restore(sess, "/tmp/model.ckpt")
        print("Model restored.")
        # Do some work with the model
        ...
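To make the selection rules concrete, here is a rough pure-Python stand-in that filters plain name strings in the spirit of the snippet above. It is only a sketch: the real TF-slim helpers operate on variables in the graph, and `select_names` and the name list are hypothetical.

```python
def select_names(names, include=None, exclude=None):
    """Keep names under any `include` scope and under no `exclude` scope.

    A scope matches a name when the name starts with it. This prefix rule
    is a simplification assumed for the sketch, not TF-slim's actual code.
    """
    def under(name, scopes):
        return any(name.startswith(scope) for scope in scopes)

    kept = [n for n in names if include is None or under(n, include)]
    if exclude:
        kept = [n for n in kept if not under(n, exclude)]
    return kept


graph_names = ["v1", "nested/v2"]  # hypothetical variable names

by_include = select_names(graph_names, include=["nested"])
by_exclude = select_names(graph_names, exclude=["v1"])
by_suffix = [n for n in graph_names if n.endswith("2")]  # suffix-based selection
print(by_include, by_exclude, by_suffix)
```

Because matching in this sketch is by prefix, a scope such as "fc8" would not match a name like "vgg_16/fc8/weights". It is worth checking the actual variable names in the graph before writing include or exclude lists.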
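When restoring variables from a checkpoint, you can create a restorer with restorer = tf.train.Saver(). With no arguments, the saver covers all variables in the current graph and looks each of them up in the checkpoint by name. You can also pass a list of variables to restore when creating tf.train.Saver(); in that case the names the saver looks up in the checkpoint are taken implicitly from each variable's var.op.name.
But how do you restore a variable whose name in the checkpoint differs from its name in the current graph?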
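(3) Pass a dictionary to tf.train.Saver()
When a variable's name in the checkpoint differs from its name in the current graph, pass tf.train.Saver() a dictionary that maps each variable's name in the checkpoint to the corresponding variable in the current graph. The example below shows a simple way to build this dictionary.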
    import tensorflow as tf

    slim = tf.contrib.slim

    # Option A: assuming that 'conv1/weights' should be restored from
    # 'vgg16/conv1/weights'
    def name_in_checkpoint(var):
        return 'vgg16/' + var.op.name

    # Option B (define only one of the two): assuming that 'conv1/weights' and
    # 'conv1/bias' should be restored from 'conv1/params1' and 'conv1/params2'
    def name_in_checkpoint(var):
        if "weights" in var.op.name:
            return var.op.name.replace("weights", "params1")
        if "bias" in var.op.name:
            return var.op.name.replace("bias", "params2")

    variables_to_restore = slim.get_model_variables()
    # Map each name in the checkpoint to the variable in the current graph.
    variables_to_restore = {name_in_checkpoint(var): var for var in variables_to_restore}
    restorer = tf.train.Saver(variables_to_restore)

    with tf.Session() as sess:
        # Restore variables from disk.
        restorer.restore(sess, "/tmp/model.ckpt")
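Before handing such a dictionary to the saver, it can help to verify that every mapped name actually exists in the checkpoint. The sketch below does this on plain strings. `checkpoint_names`, `graph_names`, and `build_name_map` are hypothetical, and in a real program the checkpoint names would have to be read from the checkpoint file.

```python
def build_name_map(graph_names, checkpoint_names, prefix="vgg16/"):
    """Map checkpoint names to graph names, assuming a fixed prefix.

    Raises KeyError if a derived name is missing from the checkpoint,
    so a naming mismatch surfaces before any restore is attempted.
    """
    available = set(checkpoint_names)
    name_map = {}
    for graph_name in graph_names:
        ckpt_name = prefix + graph_name
        if ckpt_name not in available:
            raise KeyError("not in checkpoint: " + ckpt_name)
        name_map[ckpt_name] = graph_name
    return name_map


# Hypothetical names, for illustration only.
checkpoint_names = ["vgg16/conv1/weights", "vgg16/conv1/bias", "vgg16/fc8/weights"]
graph_names = ["conv1/weights", "conv1/bias"]

name_map = build_name_map(graph_names, checkpoint_names)
print(name_map)
```

Unlike the dictionary passed to tf.train.Saver, whose values are variable objects, this sketch maps strings to strings so that it runs without TensorFlow.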
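II. An example
Suppose a VGG16 pre-trained on the 1000-class ImageNet dataset is to be used for a 20-class task. The model is initialized from the pre-trained VGG16 weights, excluding the last layer. Note that the snippet below goes further and restores only the convolutional layers, excluding fc6, fc7, and fc8.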
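    import tensorflow as tf
    import tensorflow.contrib.slim.nets as nets  # assumes the TF 1.x contrib layout

    slim = tf.contrib.slim
    vgg = nets.vgg

    # Load the Pascal VOC data
    image, label = MyPascalVocDataLoader(...)
    images, labels = tf.train.batch([image, label], batch_size=32)

    # Create the model (vgg_16 returns the output tensor and a dict of end points)
    predictions, _ = vgg.vgg_16(images, num_classes=20)

    train_op = slim.learning.create_train_op(...)

    # Specify where the Model, trained on ImageNet, was saved.
    model_path = '/path/to/pre_trained_on_imagenet.checkpoint'

    # Specify where the new model will live:
    log_dir = '/path/to/my_pascal_model_dir/'

    # Restore only the convolutional layers.
    # Note: the scopes must match the variable names in the graph; if the model
    # lives under a 'vgg_16/' scope, they may need to be 'vgg_16/fc6', etc.
    variables_to_restore = slim.get_variables_to_restore(exclude=['fc6', 'fc7', 'fc8'])
    init_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore)

    # Start training.
    slim.learning.train(train_op, log_dir, init_fn=init_fn)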
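Conceptually, the `init_fn` passed to the training loop is a callable that copies the selected checkpoint values into the freshly initialized model and leaves everything else untouched. The following framework-free sketch mimics that idea with dictionaries. `make_init_fn`, `checkpoint`, and `model` are hypothetical stand-ins, not the TF-slim API.

```python
def make_init_fn(checkpoint, names_to_restore):
    """Return a function that overwrites selected entries of a model dict."""
    missing = [n for n in names_to_restore if n not in checkpoint]
    if missing:
        raise KeyError("not in checkpoint: %s" % missing)

    def init_fn(model):
        # Copy only the selected values; other entries keep their init.
        for name in names_to_restore:
            model[name] = checkpoint[name]

    return init_fn


# Hypothetical scalar values standing in for weight tensors.
checkpoint = {"conv1/weights": 0.5, "conv5/weights": 0.7, "fc8/weights": 0.9}
model = {"conv1/weights": 0.0, "conv5/weights": 0.0, "fc8/weights": 0.0}

# Restore only the convolutional layers; fc8 keeps its fresh initialization.
to_restore = [n for n in model if n.startswith("conv")]
init_fn = make_init_fn(checkpoint, to_restore)
init_fn(model)
print(model)
```

Since the excluded layers are never read from the checkpoint, the new 20-class fc8 can have a different shape from the 1000-class fc8 stored in the pre-trained model.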