tf.train.Saver(tf.contrib.framework.get_variables_to_restore(include=…,exclude=…))控制变量加载,即迁移学习
踩坑原因:
使用tf.train.Saver()创建一个Saver对象,因为使用了预训练模型,所以设置了Saver中需要加载和不需要加载的变量,但是我最后用这个Saver对象去保存训练得到的ckpt文件,结果就是ckpt文件中只保存了include中包含的变量,或者是保存除了exclude中包含的变量的其他变量
造成的bug:
重新加载前面保存的ckpt模型文件时,很多变量找不到,模型初始化失败,例如我的就是找不到global_step变量:
正常ckpt中包含global_step
只用加载时的Saver对象保存ckpt,就没有global_step
解决办法或原理
训练模型时重新创建一个Saver对象,如saver_to_save=tf.train.Saver(),保存ckpt时用这个Saver对象去保存,ckpt中就会存储所有的模型参数信息,也就是说需要创建两个Saver对象,一个用来控制加载预训练模型中的参数,一个用来保存训练的模型参数
tensorflow加载预训练模型应该是以现有图中的变量为基础来从ckpt中查找,如果不指定需要加载的变量,那么tensorflow就会认为当前图中的变量在checkpoint中都存在。