需求如下:先用Resnet50在ImageNet上预训练,最后一层输出为类别数量,设为1000。然后将保存下来的参数迁移到PascalVoc上训练。
问题:由于PascalVoc只有20类,所以Resnet50最后一层输出要改为20。此时直接用tf.train.Saver()的restore,因为预训练的参数最后一层resnet50/fc长度为1000,而新模型最后一层resnet50/fc长度为20,会不匹配造成加载失败。
解决方法:利用tf.contrib.framework.get_variables_to_restore()函数,代码如下
variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['resnet50/fc']) saver = tf.train.Saver(variables_to_restore) with tf.Session() as sess: se