tensorflow利用saver读取部分参数变量值

做实验时到的问题:

'feature_embeddings:0' not found in checkpoint

当时的实验室分别训练modelA和modelB,再将B模型的参数载入到A中,具体如下图所示:

在这里插入图片描述

modelA中包含modelB中所有参数,可以将modelA中参数载入ModelB,但是反过来则报错,具体载入语句为:

# 存储
saver = tf.train.Saver(max_to_keep=5) 
saver.save(self.sess, self.save_path + 'model.ckpt')


# 载入

def restore(sess, saver, save_path=None):
        print("载入Intract层参数!!!")
        if (save_path == None):
            save_path = self.save_path
        ckpt = tf.train.get_checkpoint_state(save_path)  
        if ckpt and ckpt.model_checkpoint_path:  
            saver.restore(sess, ckpt.model_checkpoint_path) 
            if verbose > 0:
                print ("restored from %s" % (save_path))

name如何解决modelB训练参数载入modelA呢?

参考:
xys430381_1 的博客
双木青橙 的博客
ncc1995 的博客

简言之:
就是在Saver内添加var_list参数,而且必须
两个模型都要添加!!!
两个模型都要添加!!!
两个模型都要添加!!!

# 获取模型所有训练参数
trainable_vars = tf.trainable_variables()
# 只保留 feature开始的参数
embed_var_list = [t for t in self.trainable_vars if t.name.startswith(u'feature')]
#只对embed_var_list进行存取
saver = tf.train.Saver(var_list=embed_var_list)

这样只存取 modelA 和 B 的b c d 参数即可,再次B模型参数载入A中便不会报错。

注:即使B模型中trainable_vars 中只包含 b c d 变量 Saver也必须填写 var_list 否则还会报错

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值