踩坑实录:
在做迁移学习的时候经常会碰到 增加了新的层却需要调取已有模型的部分参数的情况
因为已有的checkpoint里并没有新加入层的variables,报错为:
Key xxx not found in checkpoint
可以通过get_collection/看看该层的所有variables
var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope=‘新加入层的scope’)
或者
var=slim.get_variables('新加入层的scope')
在restore的时候,知道哪些variables是与新加入的层相关的之后,exclude这些variables就好了,(但我在解决实际问题的时候 发现tf无法正则匹配scope的关键字 比如新增加的scope为 scope1=model/inference/dense 那么这个scope属于绝对路径 比如scope2=model/optimizer/model/inference/dense 这个就得重新写入exclude list 我们只能比较前后有哪些key是不同的 然后剔除这些不同的key):
var_to_restore=tf.contrib.framework.get_variables_to_restore(exclude=['scope1','scope2'...])
saver=tf.train.Saver(var_to_restore,max_to_keep=xx)