TensorFlow预训练模型在新图中权重部分加载

首先对预训练模型的scope一定要做好定义,不然恢复起来会比较麻烦。

这里使用tf.get_collection()

1、tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='name')
tf.get_collection(
    key,
    scope=None
)

Args:

  • key: The key for the collection. For example, the GraphKeys class contains many standard names for collections.
  • scope: (Optional.) If supplied, the resulting list is filtered to include only items whose name attribute matches using re.match. Items without a name attribute are never returned if a scope is supplied and the choice or re.match means that a scope without special tokens filters by prefix.

Returns:

The list of values in the collection with the given name, or an empty list if no value has been added to that collection. The list contains the values in the order under which they were collected.

 tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)用于获取当前图下,给定指定name的所有变量,并返回由这些变量构成的list。

2、申请saver

saver = tf.train.Saver(var_list=var)

这里表示当前的这个saver只对var中的变量进行恢复,其余的不管

3、载入之前预训练的ckpt

saver.restore(sess,MODELPATH)

这里表示指定恢复的变量的权重是从MODELPATH里面来的,MODELPATH是之前预训练模型的ckpt

如果有多个这样的scope需要恢复的话可以多次重复上述步骤

最后,代码总结

var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='name1')
saver.restore(sess,MODELPATH1)

var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='name2')
saver = tf.train.Saver(var_list=var)

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值