一、错误信息
e:\ProgramData\Anaconda3\envs\nlp\lib\site-packages\tensorflow\python\training\saver.py in restore(self, sess, save_path)
1289 # a helpful message (b/110263146)
1290 raise _wrap_restore_error_with_msg(
-> 1291 err, "a Variable name or other graph key that is missing")
1292 # This is an object-based checkpoint. We'll print a warning and then do
1293 # the restore.
NotFoundError: Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
Key beta1_power_3 not found in checkpoint
[[node save/RestoreV2 (defined at <ipython-input-6-ce8858e0c03f>:17) ]]
二、报错分析
"Key xxx not found"
看到变量名字,我就知道了为什么会报错......
原来的模型写完之后跑了一次save了,然后我现在想对loss加一个正则项。
这个正则项对应几个新变量,在原来的图中是没有的。
这个restore笨就笨在,既然你在ckpt文件里找不到我定义的新变量,你就把能找到的老变量先读进来啊......
三、解决方案
研究了一下怎么解决。
原来只需要给savor声明要restore哪些变量就可以了。
而这个声明在savor的__init__方法里......
def __init__(self,
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=saver_pb2.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None):
""" var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
names to `SaveableObject`s. If `None`, defaults to the list of all
saveable objects.
"""
self._var_list = var_list
看注释部分,如果你不指定var_list,默认会把当前图里面所有可以存的变量都弄进来。
也就是说,出错的核心原因在于我的这句
saver = tf.train.Saver()
#实例化时没有指定var_list
#于是默认list中包含了当前图的所有变量
在调用restore的时候,它本质上是用self._var_list里的key,逐一地去ckpt文件里面找。
因为我没有指定var_list,所以self._var_list里面,含有新定义的几个变量。
这几个变量自然在ckpt里面找不到了。
看源码:
def object_graph_key_mapping(checkpoint_path):
"""Return name to key mappings from the checkpoint.
Args:
checkpoint_path: string, path to object-based checkpoint
Returns:
Dictionary mapping tensor names to checkpoint keys.
"""
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
print("reader type is " + str(type(reader)))
object_graph_string = reader.get_tensor(
checkpointable.OBJECT_GRAPH_PROTO_KEY)
print( type(object_graph_string))
print(object_graph_string)
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string)
names_to_keys = {}
print(succeed)
for node in object_graph_proto.nodes:
for attribute in node.attributes:
names_to_keys[attribute.full_name] = attribute.checkpoint_key
return names_to_keys
所以呢,只要在一开始声明你想restore哪几个变量就可以了。
"""
#假定已经获得了var_list
#只需做如下声明
"""
saver=tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
savor.restore(sess, ckpt_path)
如此一来,这个savor就知道要找哪些变量了。
//如何获取var_list参考我写的这篇Tensorflow 获取model中的变量列表
//其他参考TensorFlow restore部分变量
四、其他思路
声明:这个方法我没有做过严谨的数值检验,只是测试的时候发现好像可行。
"""
#我在上面建了一个新模型,存在默认图中
"""
#在默认图中定义saver
with tf.Session(graph=tf.get_default_graph()) as sess:
saver = tf.train.Saver()
#定义新图g1
g1=tf.Graph()
#在g1中定义new_saver,但是用saver来执行restore
with tf.Session(graph=g1) as sess:
new_saver = tf.train.import_meta_graph('e:/20190227_01/mytrain/results/20190227_01/model/model.meta')
print(len(saver._var_list))
saver.restore(sess, 'e:/20190227_01/mytrain/results/20190227_01/model/model')
var_list=tf.global_variables()
t_list = saver._var_list
print(len(var_list))
print(len(t_list))
print(sess.run('g_b2:0'))
#output:
117
INFO:tensorflow:Restoring parameters from e:/20190227_01/mytrain/results/20190227_01/model/model
97
117
[ 0.01046147 0.03958455 -0.02151166 0.02771805 0.01447319 -0.0239599
-0.02140592 -0.00140079 -0.03730477 -0.02490073 0.0016868 0.02542606
-0.00882351 -0.00958596 0.02813776 0.0163886 -0.00110352 0.00578698
0.00511657 -0.008697 -0.00192673 0.00853274 0.00700764 -0.00846427
-0.02351797 0.00648104 -0.00412801 -0.02708812 0.00615729 -0.02213575
0.0020713 0.02355388 0.01304166 0.02849897 0.01101353 -0.0287178
0.00345481 -0.00946883 -0.02204182 0.00912443 -0.01256032 -0.01778828
-0.00084766 0.00020414 -0.00465849 -0.02098195 0.00441032 -0.00695921
0.0279315 -0.02065218 -0.02291382 -0.00816657 0.03027197 -0.01709834
0.00097829 0.00969903 -0.03033545 -0.01297254 -0.01039428 0.00054954
-0.01960606 0.00232905 -0.01519394 -0.0269644 0.01490234 -0.03364281
-0.02138902 -0.00395826 0.02313505 -0.00587205 -0.00563756 -0.00737901
0.02609836 -0.01434103 -0.01351835 -0.01616468 -0.01189714 -0.0319529
0.01048764 0.00732109 0.02988533 0.00657283 -0.0073387 0.00664425
-0.00182724 0.00090342 0.01468307 -0.01235422 0.0088566 -0.01289832
-0.03251986 -0.00802612 0.01519669 -0.01438498 -0.00191538 -0.01055768
0.02161908 -0.03298979 0.01225533 -0.00624389 0.0082115 -0.03947724
-0.00044154 -0.00988308 -0.00825887 -0.00105376 0.00201531 -0.00304493
0.01679732 -0.01240462 -0.02326259 -0.0055011 -0.00459422 0.0162851
-0.02608908 -0.02271436 -0.021041 -0.01203219 0.0335 -0.00448477
0.01181199 0.00394878 -0.02386374 0.0001303 -0.00294839 0.00425095
-0.0296011 0.02793927]
这个思路很神奇!
按照我们在part3的分析,restore的过程,是去取self._var_list再一个个跟ckpt文件匹配。
那么在默认图中定义的saver(),由于未指定变量列表,包含了默认图的全部117个变量。
而新图g1,导入了老的模型,print出来可以看到,只有97个变量。
这个过程为什么没有报错???
saver.restore()取self._var_list里面的117个变量来匹配,总有20个是ckpt里面没有的吧,照理来说不是应该返回not found错误吗?
从定义在默认图的saver也可以帮助g1来restore这点看,Saver类的信息中不包含对Graph类的从属关系。
也就是说Saver类是独立于Graph类存在的。
这点可以部分解释,我们在g1中,任意使用一个Saver类的实例,都能完成restore的工作。
但是这违反了源码中的self._var_list与ckpt的map一一对应的关系。
除非是出了bug......
最后!
我来劝退一波!
放弃tensorflow吧!!
torch大法好!
只需要5行代码,解决一切tensorflow的烦恼!
//https://blog.csdn.net/m0_37615398/article/details/85042176