Tensorflow 图学习 模型保存,导入,再训练,再保存。

题目起这么长,其实是说怎么用tensorflow做模型的保存,导入,fine-tune/再训练。
碰到问题时感觉一头雾水,只怪自己学艺不专,从CSDN大学各个大佬帖子中都是东一榔头,西一棒子。学得很不系统。
但有多少人从一开始就非常系统呢?

而当我掌握了上面所说的技能的时候,我想说,真简单!
可能是会了就不难了吧。
我就把自己的code贴在这里,希望对还在入门tensorflow的人有所帮助。

我曾遇到的问题,都是在尝试过程中犯的非常简单的问题:

  1. AssertionError: Do not use tf.reset_default_graph() to clear nested graphs.
  2. FailedPreconditionError (see above for traceback): Attempting to use uninitialized value conv2d_1/kernel
  3. Not found: Key gcnmodelae/graphconvolution_1_vars/weights/Adam not found in checkpoint
  4. Attempting to use uninitialized value Variable
  5. NotFoundError: Restoring from checkpoint failed.
  6. The Session graph is empty.
  7. Cannot interpret feed_dict key as Tensor: Tensor Tensor (...) is not an element of this graph.
  8. NotFoundError: Restoring from checkpoint failed.

入坑了tensorflow,使用图学习模型做无监督特征提取。主要是使用graph autoencoder(GAE) 提特征,然后做分类。

因为已经有了经过预训练的网络,就想要使用预训练的网络权值做初始化,然后利用新的数据集重新训练网络。自己花了五天,踩了挺多莫名奇妙的坑,看到网上的教程很多,但是还没有一个是我这样的需求,就记录一下,分享下 踩坑经验。

总结了一下,大概有这么三个命令十分关键,大部分的问题都来自这三个命令。

 1. saver = tf.train.Saver() # tf.trainable_variables()
 2. sess.run(tf.global_variables_initializer())
 3. saver.restore(sess_anch, checkpoint_pth)

除了这些,如何定义session和graph也挺重要。

g_anchor = tf.Graph()
sess_anch = tf.Session(graph=g_anchor)

按照我的理解,我遇到的问题,以及解决办法如下。短短几天的了解还比较有限,就持续更新吧。

1. 模型保存

saver = tf.train.Saver() # tf.trainable_variables()
位置的错放,可能导致想保存的参数保存不上。

sess.run(tf.global_variables_initializer())
位置的错放,可能导致已保存的参数不被更新,或者被重置为初始化的值。

2. 模型加载

sess.run(tf.global_variables_initializer())
位置的错放,可能导致已保存的参数不被更新。

3. 模型再训练

saver = tf.train.Saver()

加不加具体的meta文件路径,会导致模型的参数无法被更新。

4. 再训练模型保存

saver.restore(sess_anch, checkpoint_pth)
是否要重新定义新保存的路径

整个的框架可以这样设计

import ***

g = tf.Graph()
# sess = tf.Session(graph=g)
# 这样定义会内存泄漏
# 而需要改成with形式生成sess,同时调用graph
with tf.Session(graph=g) as sess:
    """
    define your model graph
    """
    
    ***
    
    saver = tf.train.Saver() # define a global saver.
    

def some_functions():
    """
    data loader!
    """
    for iteratively: # load and save your model multi-times
       tf.reset_default_graph()
       with tf.Session(graph=g) as sess: # tf.Session() as sess:
           sess.run(tf.global_variables_initializer())
           saver_anch.restore(sess, tf.train.latest_checkpoint('Your model path here')) # restore your model.
           if retrain:
                 outs = sess.run([opt.opt_op, opt.cost, opt.accuracy, emb], feed_dict=feed_dict)
           else:
                embedding = sess.run(emb, feed_dict=feed_dict)
                
    saver.save(sess, your_model_path) # save model
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值