重载模型参数

tensorflow中做迁移学习时,会遇到模型参数值与网络模型参数定义不一致的问题,而无法载入模型参数。如参数名字改变,模型参数shape不一致,直接调用saver.restore()会报错。需要手动修改这个restore()函数。
这里提供两种恢复方式:
(1)第一种方法,提取出ckpt中参数与模型中的参数名字一致,shape大小也一致,然后再调用saver.restore()恢复。

restore_1(session,save_file):
    ckpt = tf.train.get_checkpoint_state(save_file)
    reader =tf.train.NewCheckpointReader(ckpt.model_checkpoint_path)
    saved_shapes = reader.get_variable_to_shape_map()
   var_names = sorted([(var.name, var.name.split(‘:’)[0])    
                                     for var in tf.global_variables()
                                    if var.name.split(':')[0] in saved_shapes])
    with tf.variable_scope(‘’, reuse=True:
        for var_name,saved_var_name in var_names:
            curr_var = name2var[saved_var_name]
            var_shape = curr_var.get_shape().as_list()
            if var_shape == saved_shapes[saved_var_name]:
                restore_vars.append(curr_var)
     saver = tf.train.Saver(restore_vars)
     saver.restore(session, ckpt.model_checkpoint_path)

(2)第二种手动恢复参数,相较第一种灵活性更高。只选取待恢复的参数,同时可对ckpt中的参数做修改。

restore_2(session, save_file):
     ckpt = tf.train.get_checkpoint_state(save_file)
     reader = tf.train.NewCheckpointReader(ckpt.model_checkpoint_path)
     saved_shapes = reader.get_variable_to_shape_map()
     var_names1 = sorted(set([str((var.name.split(':')[0]).split('/')[0])
                                for var in tf.global_variables()
                              if var.name.split(':')[0]  in  saved_shapes]))
     var_names2 = sorted(set([str((var.name.split(':')[0]).split('/')[1])
                                             for var in tf.global_variables()
                                     if var.name.split(':')[0]  in  saved_shapes]))
     for key in var_names1:
         with tf.variable_scope(key, reuse=True):
             for key1 in var_names2:
                 VAR = tf.get_variable(key1)
                 sess.run(tf.assign(VAR, reader.get_tensor(key+'/'+key1)))
              except ValueError:
                  print('ignore:',key)          
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值