tensorflow增加新的层后重载模型部分参数

踩坑实录:
在做迁移学习的时候经常会碰到 增加了新的层却需要调取已有模型的部分参数的情况
因为已有的checkpoint里并没有新加入层的variables,报错为:

Key xxx not found in checkpoint

可以通过get_collection/看看该层的所有variables
var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope=‘新加入层的scope’)
或者
var=slim.get_variables('新加入层的scope')

在restore的时候,知道哪些variables是与新加入的层相关的之后,exclude这些variables就好了,(但我在解决实际问题的时候 发现tf无法正则匹配scope的关键字 比如新增加的scope为 scope1=model/inference/dense 那么这个scope属于绝对路径 比如scope2=model/optimizer/model/inference/dense 这个就得重新写入exclude list 我们只能比较前后有哪些key是不同的 然后剔除这些不同的key):

var_to_restore=tf.contrib.framework.get_variables_to_restore(exclude=['scope1','scope2'...])
saver=tf.train.Saver(var_to_restore,max_to_keep=xx)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值