TensorFlow 输出和修改checkpoint 中的变量名与变量

看上了两篇文章~想留着哈哈哈~转载自

https://blog.csdn.net/qq_32799915/article/details/80312928

https://zhuanlan.zhihu.com/p/36982683

哈哈哈

TensorFlow 输出checkpoint 中的变量名与变量值
[python]  view plain  copy
  1. import os  
  2. from tensorflow.python import pywrap_tensorflow  
  3. model_dir="/xxxxxxxxx/model.ckpt" #checkpoint的文件位置  
  4. # Read data from checkpoint file  
  5. reader = pywrap_tensorflow.NewCheckpointReader(model_dir)  
  6. var_to_shape_map = reader.get_variable_to_shape_map()  
  7. # Print tensor name and values  
  8. for key in var_to_shape_map:  
  9.     print("tensor_name: ", key)  #输出变量名  
  10.     print(reader.get_tensor(key))   #输出变量值  

输出结果:
这里只输出了变量名

Tensorflow修改已训练模型变量名字的方法

你是否有遇到以下几个场景,特别需要修改tensorflow已训练模型变量名字呢?

  1. 需要从预训练模型恢复权重,而使用框架不同导致某些层变量名字不一样,但基本的网络结构都可以一一对应的时候,如slim与tensorlayer;
  2. 转换模型框架,如使用某些工具转换tensorflow模型到caffe模型,因为某些变量名字与转换工具定义的BN层变量名字不一的时候;
  3. 想修改变量名字长度的时候;等等。

那么,我们该如何修改呢?

首先,常用的tensorflow已训练完成的模型有checkpoint和已固化为pb这两种格式,但是,由于使用了tf.contrib.framework.list_variables,目前暂且仅支持checkpoint格式;

其次,我们只希望更改变量名字,不希望动到图结构,所以,我们需要先恢复图模型只更改其中的变量名字;

最后,贴上我实验的代码及其运行效果!

# -*- coding:utf-8 -*-
#!/usr/bin/env python

'''
############################################################
rename tensorflow variable.
############################################################
'''

import tensorflow as tf
import argparse
import os
import re

def get_parser():
    parser = argparse.ArgumentParser(description='parameters to rename tensorflow variable!')
    parser.add_argument('--ckpt_path', type=str, help='the ckpt file where to load.')
    parser.add_argument('--save_path', type=str, help='the ckpt file where to save.')
    parser.add_argument('--rename_var_src', type=str, help="""Comma separated list of replace variable from""")
    parser.add_argument('--rename_var_dst', type=str, help="""Comma separated list of replace variable to""")
    parser.add_argument('--add_prefix', type=str, help='prefix of newname.')
    args = parser.parse_args()
    return args

def load_model(model_path, input_map=None):
    # Check if the model is a model directory (containing a metagraph and a checkpoint file)
    #  or if it is a protobuf file with a frozen graph
    model_exp = os.path.expanduser(model_path)
    if (os.path.isfile(model_exp)):
        print('not support: %s' % model_exp)
    else:
        print('Model directory: %s' % model_exp)
        meta_file, ckpt_file = get_model_filenames(model_exp)

        print('Metagraph file: %s' % meta_file)
        print('Checkpoint file: %s' % ckpt_file)

        saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file), input_map=input_map)
        saver.restore(tf.get_default_session(), os.path.join(model_exp, ckpt_file))

    return saver

def get_model_filenames(model_dir):
    files = os.listdir(model_dir)
    meta_files = [s for s in files if s.endswith('.meta')]
    if len(meta_files) == 0:
        raise ValueError('No meta file found in the model directory (%s)' % model_dir)
    elif len(meta_files) > 1:
        raise ValueError('There should not be more than one meta file in the model directory (%s)' % model_dir)
    meta_file = meta_files[0]
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_file = os.path.basename(ckpt.model_checkpoint_path)
        return meta_file, ckpt_file

    meta_files = [s for s in files if '.ckpt' in s]
    max_step = -1
    for f in files:
        step_str = re.match(r'(^model-[\w\- ]+.ckpt-(\d+))', f)
        if step_str is not None and len(step_str.groups()) >= 2:
            step = int(step_str.groups()[1])
            if step > max_step:
                max_step = step
                ckpt_file = step_str.groups()[0]
    return meta_file, ckpt_file

def rename(args):
    '''rename tensorflow variable, just for checkpoint file format.'''

    replace_from = args.rename_var_src.strip().split(',')
    replace_to = args.rename_var_dst.strip().split(',')

    assert len(replace_from) == len(replace_to)

    with tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(args.ckpt_path):
            # Load the variable
            var = tf.contrib.framework.load_variable(args.ckpt_path, var_name)

            # Set the new name
            new_name = var_name

            for index in range(len(replace_from)):
                new_name = new_name.replace(replace_from[index], replace_to[index])

            if args.add_prefix:
                new_name = args.add_prefix + new_name

            print('Renaming %s to %s.' % (var_name, new_name))
            # Rename the variable
            var = tf.Variable(var, name=new_name)

        # Save the variables
        saver = load_model(args.ckpt_path)
        sess.run(tf.global_variables_initializer())
        saver.save(sess, args.save_path)

if __name__ == '__main__':
    args = get_parser()
    rename(args)

将代码复制到文件rename_tf_variable.py,然后参照如下命令行执行脚本:

python rename_tf_variable.py --ckpt_path ~/github/MobileFaceNet_local/output/ckpt/ --save_path /home/xsr-ai/scripts/ckpt --rename_var_src gamma,moving_mean,moving_variance --rename_var_dst scale,mean,variance

这个命令是想将batch norm层的“gamma,moving_mean,moving_variance”改为“scale,mean,variance”,执行效果如下:




  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 11
    评论
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值