tf.contrib.slim学习之微调模型fine-truning

在使用经典的网络模型(如VGG16,下文来VGG16来讲述)用于自己的任务task时,可根据自己的需求选择是否使用(VGG16)在ImageNet预训练的权重来恢复模型参数:

  • 不使用预训练的参数,自己训练全部参数;当数据量足够时,这种方法能充分发挥模型的威力,取得较高的性能;
  • 使用预训练的权重恢复除fc8层参数之外的参数,只训练fc8一层,相当于把VGG16模型当成一个特征提取器,用fc7层提起的特征做一个SoftMax模型分类,这样做的优势在于训练难度小,所需的数据少并且训练速度快,但是往往模型性能不会太好;
  • 恢复部分层参数,训练部分参数;通常是固定浅层参数不变,训练深层参数,如固定conv1-conv5,训练fc6,fc7,fc8


一 从预训练的checkpoint恢复参数的方法

(1)使用tf.train.Saver恢复模型的部分或全部参数

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to restore all the variables.
restorer = tf.train.Saver()

# Add ops to restore some variables.
restorer = tf.train.Saver([v1, v2])

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  restorer.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Do some work with the model
  ...

(2)通过变量名从checkpoint恢复变量

TF-slim提供了一系列的辅助函数去选择当前graph中变量的子集

# Create some variables.
v1 = slim.variable(name="v1", ...)
v2 = slim.variable(name="nested/v2", ...)
...

# Get list of variables to restore (which contains only 'v2'). These are all
# equivalent methods:
variables_to_restore = slim.get_variables_by_name("v2")
# or
variables_to_restore = slim.get_variables_by_suffix("2")
# or
variables_to_restore = slim.get_variables(scope="nested")
# or
variables_to_restore = slim.get_variables_to_restore(include=["nested"])
# or
variables_to_restore = slim.get_variables_to_restore(exclude=["v1"])

# Create the saver which will be used to restore the variables.
restorer = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
  # Restore variables from disk.
  restorer.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Do some work with the model
  ...

 在从checkpoint恢复变量时,可使用restorer=tf.train.Saver()创建一个恢复器,tf.train.Saver()会查找checkpoint中的变量名,并将他们匹配到在当前图中的变量,也可以在创建tf.train.Saver()时传递一个要恢复的变量列表,这种情况下,tf.train.Saver要在checkpoint中查找的变量名是隐含的从变量列表中获得的var.op.name

但有时要恢复的变量在checkpoint中的变量名与当前graph中的变量名不同是如何恢复呢

(3)给tf.trian.Saver()提供一个字典

要恢复的变量在checkpoint中的变量名与当前graph中的变量名不同时,在创建tf.train.Saver()时,需要提供一个每个变量在checkpoint中的名字到该变量在当前graph中名字映射的字典;可以通过下例中简单的方法创建这个映射字典。

# Assuming than 'conv1/weights' should be restored from 'vgg16/conv1/weights'
def name_in_checkpoint(var):
  return 'vgg16/' + var.op.name

# Assuming than 'conv1/weights' and 'conv1/bias' should be restored from 'conv1/params1' and 'conv1/params2'
def name_in_checkpoint(var):
  if "weights" in var.op.name:
    return var.op.name.replace("weights", "params1")
  if "bias" in var.op.name:
    return var.op.name.replace("bias", "params2")

variables_to_restore = slim.get_model_variables()
variables_to_restore = {name_in_checkpoint(var):var for var in variables_to_restore}
restorer = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
  # Restore variables from disk.
  restorer.restore(sess, "/tmp/model.ckpt")

二 一个栗子

在有1000类的ImageNet预训练的VGG16用于20分类的情况下,使用pre_trained VGG16(出去最后一层)中的数据来初始化模型

# Load the Pascal VOC data
image, label = MyPascalVocDataLoader(...)
images, labels = tf.train.batch([image, label], batch_size=32)

# Create the model
predictions = vgg.vgg_16(images)

train_op = slim.learning.create_train_op(...)

# Specify where the Model, trained on ImageNet, was saved.
model_path = '/path/to/pre_trained_on_imagenet.checkpoint'

# Specify where the new model will live:
log_dir = '/path/to/my_pascal_model_dir/'

# Restore only the convolutional layers:
variables_to_restore = slim.get_variables_to_restore(exclude=['fc6', 'fc7', 'fc8'])
init_fn = assign_from_checkpoint_fn(model_path, variables_to_restore)

# Start training.
slim.learning.train(train_op, log_dir, init_fn=init_fn)

 

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值