【TensorFlow】使用slim从ckpt里导出指定层的参数

利用函数 slim.assgin_from_checkpoint

相关文档:https://github.com/tensorflow/tensorflow/blob/129665119ea60640f7ed921f36db9b5c23455224/tensorflow/contrib/slim/python/slim/learning.py

 

相应的关于如何从一个ckpt文件里导出模型的部分参数用于fine_tuning:

*************************************************
* Fine-Tuning Part of a model from a checkpoint *
*************************************************
Rather than initializing all of the weights of a given model, we sometimes
only want to restore some of the weights from a checkpoint. To do this, one
need only filter those variables to initialize as follows:
  ...
  # Create the train_op
  train_op = slim.learning.create_train_op(total_loss, optimizer)
  checkpoint_path = '/path/to/old_model_checkpoint'
  # Specify the variables to restore via a list of inclusion or exclusion
  # patterns:
  variables_to_restore = slim.get_variables_to_restore(
      include=["conv"], exclude=["fc8", "fc9])
  # or
  variables_to_restore = slim.get_variables_to_restore(exclude=["conv"])
  init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
      checkpoint_path, variables_to_restore)
  # Create an initial assignment function.
  def InitAssignFn(sess):
      sess.run(init_assign_op, init_feed_dict)
  # Run training.
  slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn)

Update:

使用如下代码:

import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
predictions = nets.vgg.vgg_16(images)
print [v.name for v in slim.get_variables_to_restore(exclude=['fc8']) ]

得到输出:

[u'vgg_16/conv1/conv1_1/weights:0',
 u'vgg_16/conv1/conv1_1/biases:0',
 …
 u'vgg_16/fc6/weights:0',
 u'vgg_16/fc6/biases:0',
 u'vgg_16/fc7/weights:0',
 u'vgg_16/fc7/biases:0',
 u'vgg_16/fc8/weights:0',
 u'vgg_16/fc8/biases:0']

如果你不想导出某个层(vgg_16/fc8)的参数,

print [v.name for v in slim.get_variables_to_restore(exclude=['vgg_16/fc8']) ]

 

得到输出:

[u'vgg_16/conv1/conv1_1/weights:0',
 u'vgg_16/conv1/conv1_1/biases:0',
 …
 u'vgg_16/fc6/weights:0',
 u'vgg_16/fc6/biases:0',
 u'vgg_16/fc7/weights:0',
 u'vgg_16/fc7/biases:0']

update2:

一个完整的只取部分参数的例子:

import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets

s = tf.Session(config=tf.ConfigProto(gpu_options={'allow_growth':True}))

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
predictions = nets.vgg.vgg_16(images, 200)
variables_to_restore = slim.get_variables_to_restore(exclude=['vgg_16/fc8'])
init_assign_op, init_feed_dict = slim.assign_from_checkpoint('./vgg16.ckpt', variables_to_restore)
s.run(init_assign_op, init_feed_dict)

以上是从官方提供的slim vgg16在imagenet上训练得到的model中导出部分参数code

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值