tensorflow导入部分checkpoint

版权声明: https://blog.csdn.net/b876144622/article/details/79962727
现实中碰到一个问题,训练好分类模型,比如训练保存了一个10分类的模型,但是实际用的时候呢,可能是做20分类,但是还想继续使用前面保存的模型。那么相当于是只加载前几层的参数,最后一层做一些修改。


一般实验情况下保存的时候,都是用的saver类来保存,如下

saver = tf.train.Saver()
saver.save(sess,"model.ckpt")
加载时的代码

saver.restore(sess,"model.ckpt")


前面的描述相当于是保存了所有的参数,然后加载所有的参数。但是目前的情况有所变化了,不能加载所有的参数,最后一层的参数不一样了,需要随机初始化。如何操作呢?

首先对每一层添加name scope,如下:

with name_scope('conv1'):
        xxx
with name_scope('conv2'):
        xxx
with name_scope('fc1'):
        xxx
with name_scope('output'):
        xxx

然后根据变量的名字,选择加载哪些变量,

#得到该网络中,所有可以加载的参数
variables = tf.contrib.framework.get_variables_to_restore()
#删除output层中的参数
variables_to_resotre = [v for v in varialbes if v.name.split('/')[0]!='output']
#构建这部分参数的saver
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess,'model.ckpt')

在tensorflow中,有多种方式可以得到变量的信息:

tf.contrib.framework.get_variables_to_restore()
tf.all_variables()
tf.trainable_varialbes()

等等,可以多看看API

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值