在使用tensorflow保存模型后,想要加载模型,用来对测试数据进行测试时,出现了“Key Variable not fonund in checkpoint"问题。
谷歌了一下,发现有一篇博客已经解答了这个问题,链接在参考博客上。为了看看我的问题和他的问题是不是一样的,我按照该博客的方法,在加载模型的代码中进行如下修改。
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
saver = tf.train.Saver()
reader = pywrap_tensorflow.NewCheckpointReader(model_checkpoint_path)#model_checkpoint_path是保存模型的路径加上模型名
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor name:",key)
saver.restore(sess, model_checkpoint_path)
输出的结果:
发现确实如这篇博客所说的,是因为模型把优化的变量也给保存下来了,所以后面就按照该博客建议的,在tf.train.Saver中指定真正要保存的变量。因为我的模型比较多层,按照这篇博客那种方法就太复杂了,所以我就直接把var_list设置成tf.trainable_variables()。
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader(model_checkpoint_path)#model_checkpoint_path是保存模型的路径加上模型名
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor name:",key)
saver.restore(sess, model_checkpoint_path)
saver = tf.train.Saver(max_to_keep=FLAGS.keep_ckpt + 1 if FLAGS.keep_ckpt else 1000000, var_list = tf.trainable_variables()) #前面的max_to_keep是确认要保存的模型的个数,这个如果不指定的话可以直接去掉,只保留后面的var_list = tf.trainable_variables()
修改完成后,再次运行训练代码,将训练完成的模型保存下来。保存完成后,再次加载模型。本来以为已经大功告成了,没想到再次报错。
再次把模型打印出来,跟之前加载的模型进行对比发现:有些tensor的tensorname已经没有Adam这个名称了,但是有些还是有。仔细观察我发现,那些没有的层,都是我训练的时候设置成可训练的层。因为我在训练的时候,是将pretrained的模型加载进来,然后对pretrained的模型的后面几层进行训练的。
发现这个问题之后,我再在测试代码中,插入下面这行代码。
saver = tf.train.Saver(max_to_keep=FLAGS.keep_ckpt + 1 if FLAGS.keep_ckpt else 1000000,var_list=tf.trainable_variables())
最后模型正常加载,可以正常进行测试了。
总结:
需要在训练和测试代码中都增加下面这段代码。
saver = tf.train.Saver(max_to_keep=FLAGS.keep_ckpt + 1 if FLAGS.keep_ckpt else 1000000,var_list=tf.trainable_variables())
参考博客