如何用slim进行网络模型迁移

# encoding:utf-8
# 主要是用于迁移初始化模型
import tensorflow as tf
checkpoint_path_mata = "*/checkpoint/model.ckpt-65559.meta" #模型网络结构
checkpoint_path_data = "*/model.ckpt-65559"#模型网络参数
vgg_path = "*/checkpoint/vgg_16.ckpt" #这个是需用来迁移的模型参数 ,这里举例vgg网络
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph(checkpoint_path_mata) #导入模型网络解构,这里只是掩饰,也可以自己直接搭建
#     new_saver.restore(sess, checkpoint_path_data) #导入模型参数,为了掩饰模型迁移初始化,也可以不导入

slim = tf.contrib.slim

checkpoint_exclude_scopes="text_box_300/conv6,text_box_300/conv7,text_box_300/conv8,text_box_300/conv9,\
                           text_box_300/conv10,text_box_300/global,text_box_300/conv4_box,\
                           text_box_300/conv7_box,text_box_300/conv8_box,text_box_300/conv9_box,\
                           text_box_300/conv10_box,text_box_300/conv11_box" 
##### checkpoint_exclude_scopes 不需要初始化的网络层
exclude =[]#把不要初始化的网络层提取出来
for item in checkpoint_exclude_scopes.split(","):
    exclude.append(item.strip())

#根据字符串开头匹配,提取需要初始化的网络层 var 
variables_to_restore=[]
for var  in slim.get_model_variables():
    #tf.logging.info(var) #打印出网络层的信息
    excluded =False
    for exclusion in exclude:
        if var.op.name.startswith(exclusion):
            excluded = True
            break
    if not excluded:
        variables_to_restore.append(var)

# 提取替换信息的映射map
# Key :vgg的网络名称    value:主网络需要初始化的网络层参数变量
model_name = "text_box_300" #主网络名称
checkpoint_model_scope = "vgg_16"#用来初始化主网络的网络名称
variables_to_restore = \
            {var.op.name.replace(model_name,
                                 checkpoint_model_scope): var
             for var in variables_to_restore}
    
#------------------初始化------------------------------------------------------------------
    #方法(1)

# 直接初始化
slim.assign_from_checkpoint(
    vgg_path,
    variables_to_restore,
    ignore_missing_vars =True 
)
# ignore_missing_vars =True 表示允许 variables_to_restore中存在一些 vgg_path中没有的key
"""
列如:WARNING:tensorflow:Checkpoint is missing variable [vgg_16/conv9/conv3x3/weights]
    WARNING:tensorflow:Checkpoint is missing variable [vgg_16/conv9/conv3x3/weights]
    WARNING:tensorflow:Checkpoint is missing variable [vgg_16/conv6/biases]
    WARNING:tensorflow:Checkpoint is missing variable [vgg_16/conv6/biases]

"""
# ignore_missing_vars =False 表示不允许 variables_to_restore中存在一些 vgg_path中没有的key
"""
报错列如:
    ValueError: Checkpoint is missing variable [vgg_16/conv10/conv3x3/biases]

所以可以通过设置ignore_missing_vars =False 检查网络搭建是否正确

"""
#方法(2)
# 返回初始化函数
return slim.assign_from_checkpoint_fn(
    vgg_path,
    variables_to_restore,
    ignore_missing_vars =True 
)
"""
作为
slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master='',
        is_chief=True,
        init_fn=tf_utils.get_init_fn(*),##############作为训练时的初始化函数
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        saver=saver,
        save_interval_secs=FLAGS.save_interval_secs,
        session_config=config,
        sync_optimizer=None)
"""
 
    
    
    

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值