题目起这么长,其实是说怎么用tensorflow做模型的保存,导入,fine-tune/再训练。
碰到问题时感觉一头雾水,只怪自己学艺不专,从CSDN大学各个大佬帖子中都是东一榔头,西一棒子。学得很不系统。
但有多少人从一开始就非常系统呢?
而当我掌握了上面所说的技能的时候,我想说,真简单!
可能是会了就不难了吧。
我就把自己的code贴在这里,希望对还在入门tensorflow的人有所帮助。
我曾遇到的问题,都是在尝试过程中犯的非常简单的问题:
AssertionError: Do not use tf.reset_default_graph() to clear nested graphs.
FailedPreconditionError (see above for traceback): Attempting to use uninitialized value conv2d_1/kernel
Not found: Key gcnmodelae/graphconvolution_1_vars/weights/Adam not found in checkpoint
Attempting to use uninitialized value Variable
NotFoundError: Restoring from checkpoint failed.
The Session graph is empty.
Cannot interpret feed_dict key as Tensor: Tensor Tensor (...) is not an element of this graph.
NotFoundError: Restoring from checkpoint failed.
入坑了tensorflow,使用图学习模型做无监督特征提取。主要是使用graph autoencoder(GAE) 提特征,然后做分类。
因为已经有了经过预训练的网络,就想要使用预训练的网络权值做初始化,然后利用新的数据集重新训练网络。自己花了五天,踩了挺多莫名奇妙的坑,看到网上的教程很多,但是还没有一个是我这样的需求,就记录一下,分享下 踩坑经验。
总结了一下,大概有这么三个命令十分关键,大部分的问题都来自这三个命令。
1. saver = tf.train.Saver() # tf.trainable_variables()
2. sess.run(tf.global_variables_initializer())
3. saver.restore(sess_anch, checkpoint_pth)
除了这些,如何定义session和graph也挺重要。
g_anchor = tf.Graph()
sess_anch = tf.Session(graph=g_anchor)
按照我的理解,我遇到的问题,以及解决办法如下。短短几天的了解还比较有限,就持续更新吧。
1. 模型保存
saver = tf.train.Saver() # tf.trainable_variables()
位置的错放,可能导致想保存的参数保存不上。
sess.run(tf.global_variables_initializer())
位置的错放,可能导致已保存的参数不被更新,或者被重置为初始化的值。
2. 模型加载
sess.run(tf.global_variables_initializer())
位置的错放,可能导致已保存的参数不被更新。
3. 模型再训练
saver = tf.train.Saver()
加不加具体的meta文件路径,会导致模型的参数无法被更新。
4. 再训练模型保存
saver.restore(sess_anch, checkpoint_pth)
是否要重新定义新保存的路径
整个的框架可以这样设计
import ***
g = tf.Graph()
# sess = tf.Session(graph=g)
# 这样定义会内存泄漏
# 而需要改成with形式生成sess,同时调用graph
with tf.Session(graph=g) as sess:
"""
define your model graph
"""
***
saver = tf.train.Saver() # define a global saver.
def some_functions():
"""
data loader!
"""
for iteratively: # load and save your model multi-times
tf.reset_default_graph()
with tf.Session(graph=g) as sess: # tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver_anch.restore(sess, tf.train.latest_checkpoint('Your model path here')) # restore your model.
if retrain:
outs = sess.run([opt.opt_op, opt.cost, opt.accuracy, emb], feed_dict=feed_dict)
else:
embedding = sess.run(emb, feed_dict=feed_dict)
saver.save(sess, your_model_path) # save model