tensorflow 恢复部分参数、加载指定参数

多分类采用与训练模型输出不匹配,我们需要加载部分预训练模型的参数。

我们先看一下如何保存和读入预训练模型。

#一般实验情况下保存的时候,都是用的saver类来保存,如下
saver = tf.train.Saver()
saver.save(sess,"model.ckpt")

#加载时的代码
saver.restore(sess,"model.ckpt")

#前面的描述相当于是保存了所有的参数,然后加载所有的参数。
#但是目前的情况有所变化了,不能加载所有的参数,最后一层的参数不一样了,需要随机初始化。
#首先对每一层添加name scope,如下:

with name_scope('conv1'):
        xxx
with name_scope('conv2'):
        xxx
with name_scope('fc1'):
        xxx
with name_scope('output'):
        xxx
#然后根据变量的名字,选择加载哪些变量,

#得到该网络中,所有可以加载的参数
variables = tf.contrib.framework.get_variables_to_restore()
#删除output层中的参数
variables_to_resotre = [v for v in varialbes if v.name.split('/')[0]!='output']
#构建这部分参数的
saversaver = tf.train.Saver(variables_to_restore)
saver.restore(sess,'model.ckpt')

#在tensorflow中,有多种方式可以得到变量的信息:
tf.contrib.framework.get_variables_to_restore()
tf.all_variables()tf.trainable_varialbes()

 

 

 

 

多分类采用与训练模型输出不匹配解决方法:

利用tf.contrib.framework.get_variables_to_restore()函数,代码如下

variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['resnet50/fc'])
saver = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, param_path)

 

exclude=['resnet50/fc']表示加载预训练参数中除了resnet50/fc这一层之外的其他所有参数。

include=["inceptionv3"]表示只加载inceptionv3这一层的所有参数。

param_path是你预训练参数保存地址。

注:如果不止一个层参数需要丢弃,exclue=['a', 'b']即可。调优训练(fine_tuning)时最好把前面曾trainable设为False,只训练最后一层。
 

  • 6
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值