# encoding:utf-8
# 主要是用于迁移初始化模型
import tensorflow as tf
checkpoint_path_mata = "*/checkpoint/model.ckpt-65559.meta" #模型网络结构
checkpoint_path_data = "*/model.ckpt-65559"#模型网络参数
vgg_path = "*/checkpoint/vgg_16.ckpt" #这个是需用来迁移的模型参数 ,这里举例vgg网络
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph(checkpoint_path_mata) #导入模型网络解构,这里只是掩饰,也可以自己直接搭建
# new_saver.restore(sess, checkpoint_path_data) #导入模型参数,为了掩饰模型迁移初始化,也可以不导入
slim = tf.contrib.slim
checkpoint_exclude_scopes="text_box_300/conv6,text_box_300/conv7,text_box_300/conv8,text_box_300/conv9,\
text_box_300/conv10,text_box_300/global,text_box_300/conv4_box,\
text_box_300/conv7_box,text_box_300/conv8_box,text_box_300/conv9_box,\
text_box_300/conv10_box,text_box_300/conv11_box"
##### checkpoint_exclude_scopes 不需要初始化的网络层
exclude =[]#把不要初始化的网络层提取出来
for item in checkpoint_exclude_scopes.split(","):
exclude.append(item.strip())
#根据字符串开头匹配,提取需要初始化的网络层 var
variables_to_restore=[]
for var in slim.get_model_variables():
#tf.logging.info(var) #打印出网络层的信息
excluded =False
for exclusion in exclude:
if var.op.name.startswith(exclusion):
excluded = True
break
if not excluded:
variables_to_restore.append(var)
# 提取替换信息的映射map
# Key :vgg的网络名称 value:主网络需要初始化的网络层参数变量
model_name = "text_box_300" #主网络名称
checkpoint_model_scope = "vgg_16"#用来初始化主网络的网络名称
variables_to_restore = \
{var.op.name.replace(model_name,
checkpoint_model_scope): var
for var in variables_to_restore}
#------------------初始化------------------------------------------------------------------
#方法(1)
# 直接初始化
slim.assign_from_checkpoint(
vgg_path,
variables_to_restore,
ignore_missing_vars =True
)
# ignore_missing_vars =True 表示允许 variables_to_restore中存在一些 vgg_path中没有的key
"""
列如:WARNING:tensorflow:Checkpoint is missing variable [vgg_16/conv9/conv3x3/weights]
WARNING:tensorflow:Checkpoint is missing variable [vgg_16/conv9/conv3x3/weights]
WARNING:tensorflow:Checkpoint is missing variable [vgg_16/conv6/biases]
WARNING:tensorflow:Checkpoint is missing variable [vgg_16/conv6/biases]
"""
# ignore_missing_vars =False 表示不允许 variables_to_restore中存在一些 vgg_path中没有的key
"""
报错列如:
ValueError: Checkpoint is missing variable [vgg_16/conv10/conv3x3/biases]
所以可以通过设置ignore_missing_vars =False 检查网络搭建是否正确
"""
#方法(2)
# 返回初始化函数
return slim.assign_from_checkpoint_fn(
vgg_path,
variables_to_restore,
ignore_missing_vars =True
)
"""
作为
slim.learning.train(
train_tensor,
logdir=FLAGS.train_dir,
master='',
is_chief=True,
init_fn=tf_utils.get_init_fn(*),##############作为训练时的初始化函数
summary_op=summary_op,
number_of_steps=FLAGS.max_number_of_steps,
log_every_n_steps=FLAGS.log_every_n_steps,
save_summaries_secs=FLAGS.save_summaries_secs,
saver=saver,
save_interval_secs=FLAGS.save_interval_secs,
session_config=config,
sync_optimizer=None)
"""
如何用slim进行网络模型迁移
最新推荐文章于 2019-09-29 20:31:41 发布