Tensorflow 模型加载及部分变量初始化

最近在做预训练部分图模型,将这部分图模型重新加载到一个新的图中,并加入一些新的op。下面是一些遇到的问题,调试方法以及解决方案。

1、从已有图中restore参数

saver_restore = tf.train.import_meta_graph(meta_path_restore)
saver_restore.restore(sess,<checkpoint path>)

2、通过tensor的名字获取变量

input_y = saver_restore.get_tensor_by_name('name:0')

P.S.在实验过程中,我自己尝试了一种方法,在外部创建会话,直接将需要加载的参数通过会话加载进来也是可以的。

sess = tf.Session()
model1= Model1(Session=sess, restore='Model1参数path')
model2= Model2(Session=sess, restore='Model2参数path')

这样也是可行的,但是如果加入其它操作的话会出现attempting to use uninitialized value。出现这个问题的原因是在部分加载变量后,添加了其它操作,同时又没有做对未赋值变量做initialize。解决办法

1、先添加需要的其它操作,然后运行

tf.global_variables_initializer()

随后对部分参数加载。这里需要注意,如果是最后执行global_variables_initializer()的话,之前所有的赋值操作都会被覆盖掉,也即之前做的所有操作都是无意义的。

2、加载参数后,加入新的操作,最后对没有初始化的部分参数进行初始化操作。

uninit_vars=[]
for var in tf.all_variables():
    try:
       sess.run(var)
    except tf.errors.FailedPreconditionError:
       uninit_vars.append(var)
init_new_vars_op = tf.initialize_variables(uninit_vars)
sess.run(init_new_vars_op)

这里使用的是部分参数初始化,通过这种方法就可以避免需要加载参数后再加入其他操作无法初始化参数的问题。

附加一个可以用于查看参数变量的代码,方便调试使用:

var = tf.trainable_variables()
value = sess.run(var)
for v in value:
    print(v)

https://blog.csdn.net/ying86615791/article/details/76215363这篇写的也很好,可供参考。

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值