Restoring Saved Model Parameters
Saving and loading models: an introduction
See that blog post.
Model restoration
Code used
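import os

import tensorflow as tf  # TensorFlow 1.x API (tf.train.Saver, tf.train.latest_checkpoint)
import tensorflow.contrib.slim as slim

# FLAGS, cfgs and RESTORE_FROM_RPN are defined elsewhere in the project (not shown).


def get_restorer():
    # Look for the latest checkpoint written by our own training runs.
    checkpoint_path = tf.train.latest_checkpoint(os.path.join(FLAGS.trained_checkpoint, cfgs.VERSION))
    if checkpoint_path is not None:
        if RESTORE_FROM_RPN:
            # Resume from an RPN-stage checkpoint: skip the Fast R-CNN head variables.
            print('___restore from rpn___')
            model_variables = slim.get_model_variables()
            restore_variables = [var for var in model_variables if not var.name.startswith('Fast_Rcnn')] + [slim.get_or_create_global_step()]
            for var in restore_variables:
                print(var.name)
            restorer = tf.train.Saver(restore_variables)
        else:
            # No var_list: restore every saveable variable from our own checkpoint.
            restorer = tf.train.Saver()
        print("model restore from :", checkpoint_path)
    else:
        # No checkpoint of our own yet: start from the pretrained backbone.
        checkpoint_path = FLAGS.pretrained_model_path
        print("model restore from pretrained mode, path is :", checkpoint_path)
        model_variables = slim.get_model_variables()
        # Keep only backbone variables, excluding the classification logits layer.
        restore_variables = [var for var in model_variables if
                             (var.name.startswith(cfgs.NET_NAME)
                              and not var.name.startswith('{}/logits'.format(cfgs.NET_NAME)))]
        for var in restore_variables:
            print(var.name)
        restorer = tf.train.Saver(restore_variables)
    return restorer, checkpoint_path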
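The variable-selection rules in get_restorer() can be checked without building a graph. The sketch below reproduces them on plain name strings; NET_NAME = "resnet_v1_50" and the example names are hypothetical stand-ins for cfgs.NET_NAME and real graph variables, and select_restore_names is a made-up helper for illustration only.

```python
# Hypothetical stand-in for cfgs.NET_NAME.
NET_NAME = "resnet_v1_50"


def select_restore_names(var_names, restore_from_rpn, from_own_checkpoint):
    """Mirror get_restorer()'s filtering on variable names (illustrative only)."""
    if from_own_checkpoint:
        if restore_from_rpn:
            # RPN-stage checkpoint: drop the Fast R-CNN head.
            # (The real code also appends the global step variable.)
            return [n for n in var_names if not n.startswith("Fast_Rcnn")]
        # tf.train.Saver() with no var_list: everything is restored.
        return list(var_names)
    # Pretrained backbone: keep backbone variables, drop the logits layer.
    return [n for n in var_names
            if n.startswith(NET_NAME)
            and not n.startswith("{}/logits".format(NET_NAME))]


names = [
    "resnet_v1_50/conv1/weights:0",
    "resnet_v1_50/logits/weights:0",
    "build_rpn/rpn_conv/weights:0",
    "Fast_Rcnn/cls/weights:0",
]
print(select_restore_names(names, restore_from_rpn=False, from_own_checkpoint=False))
```

In the actual TF 1.x code the filtered items are variable objects rather than strings; they are passed to tf.train.Saver(var_list), and the weights are then loaded with restorer.restore(sess, checkpoint_path).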
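The path to the pretrained model's parameters must be supplied in FLAGS.pretrained_model_path. It is used only when tf.train.latest_checkpoint finds no checkpoint under FLAGS.trained_checkpoint/cfgs.VERSION.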
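In the original project FLAGS is defined with TensorFlow's own flags module. As a labeled stand-in, the sketch below uses Python's argparse to show the same idea: the pretrained checkpoint path is passed on the command line and read back as pretrained_model_path. The parse_flags helper and the example path are hypothetical.

```python
import argparse


def parse_flags(argv=None):
    """Hypothetical argparse stand-in for the project's FLAGS definition."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_path", required=True,
                        help="Path to the pretrained checkpoint")
    return parser.parse_args(argv)


flags = parse_flags(["--pretrained_model_path", "/data/pretrained/resnet_v1_50.ckpt"])
print(flags.pretrained_model_path)
```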
Difference between a .ckpt file and {.ckpt.meta, .ckpt.index, .ckpt.data-00000-of-00001}
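As far as restoring is concerned the two make no difference; only the path set in FLAGS.pretrained_model_path differs. For a single .ckpt file (the older V1 checkpoint format, which stores everything in one file), the path must include the file itself: folder path/<name>.ckpt. For the three files {.ckpt.meta, .ckpt.index, .ckpt.data-00000-of-00001} (the V2 format), the path must end at their shared prefix: folder path/<name>.ckpt-160186, where <name>.ckpt-160186 is the .meta file name with the .meta extension removed.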
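The path rule above can be sketched as two small helpers: checkpoint_prefix turns any file of a three-file checkpoint into the prefix to pass, and looks_restorable checks whether a path looks usable. Both names are hypothetical, and the check is only a heuristic (it looks for the file itself or for "<prefix>.index"), not how TensorFlow itself validates checkpoints.

```python
import os


def checkpoint_prefix(path):
    """Return the path to pass as pretrained_model_path (hypothetical helper).

    Single-file checkpoint: the .ckpt file itself is returned unchanged.
    Three-file checkpoint: the extension (.meta, .index or .data-...) is stripped,
    leaving the shared prefix, e.g. folder/model.ckpt-160186.
    """
    for ext in (".meta", ".index"):
        if path.endswith(ext):
            return path[: -len(ext)]
    if ".data-" in os.path.basename(path):
        return path[: path.rindex(".data-")]
    return path


def looks_restorable(path):
    # Single file: the path is an existing file.
    # Three files: the path is a prefix, so "<prefix>.index" should exist.
    return os.path.isfile(path) or os.path.isfile(path + ".index")
```

For example, checkpoint_prefix("folder/model.ckpt-160186.meta") returns "folder/model.ckpt-160186", which is the value to put in FLAGS.pretrained_model_path.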