I came across two articles I wanted to keep, so I'm reposting them here. Sources:
https://blog.csdn.net/qq_32799915/article/details/80312928
https://zhuanlan.zhihu.com/p/36982683
TensorFlow: printing the variable names and values stored in a checkpoint

    from tensorflow.python import pywrap_tensorflow

    model_dir = "/xxxxxxxxx/model.ckpt"  # path to the checkpoint file

    # Read data from the checkpoint file
    reader = pywrap_tensorflow.NewCheckpointReader(model_dir)
    var_to_shape_map = reader.get_variable_to_shape_map()

    # Print tensor names and values
    for key in var_to_shape_map:
        print("tensor_name: ", key)    # variable name
        print(reader.get_tensor(key))  # variable value

Output:
![](https://i-blog.csdnimg.cn/blog_migrate/94cafd6a3b138174ccb277d480baf7a1.png)
The screenshot shows only the variable names.
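Printing every value can flood the terminal, so listing each variable's shape next to its name is often easier to scan. Below is a minimal sketch of that idea. `describe_checkpoint` and `FakeReader` are hypothetical names, not part of TensorFlow; the only assumption about the reader is that it offers `get_variable_to_shape_map()` and `get_tensor(name)`, the two calls used in the snippet above. The stand-in reader exists only so the sketch runs without TensorFlow installed.

```python
def describe_checkpoint(reader, with_values=False):
    """Return one text line per variable, sorted by name.

    `reader` is any object exposing get_variable_to_shape_map() and
    get_tensor(name), the two calls used in the original snippet.
    """
    lines = []
    shape_map = reader.get_variable_to_shape_map()
    for name in sorted(shape_map):
        line = "tensor_name: %s, shape: %s" % (name, list(shape_map[name]))
        if with_values:
            line += ", value: %s" % (reader.get_tensor(name),)
        lines.append(line)
    return lines


class FakeReader(object):
    """Hypothetical stand-in for a checkpoint reader (no TensorFlow needed)."""

    def __init__(self, tensors):
        self._tensors = tensors

    def get_variable_to_shape_map(self):
        # Every fake tensor is a flat list, so its shape is [length].
        return {name: [len(v)] for name, v in self._tensors.items()}

    def get_tensor(self, name):
        return self._tensors[name]


# Made-up variables, for illustration only.
demo_reader = FakeReader({"bn/gamma": [1.0, 1.0], "bn/beta": [0.0, 0.0]})
for text in describe_checkpoint(demo_reader):
    print(text)
```

With TensorFlow available, the reader created by `pywrap_tensorflow.NewCheckpointReader(model_dir)` should be usable in place of `demo_reader`, since it provides the same two methods.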
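How to rename the variables of a trained TensorFlow model
Have you ever run into one of the following situations, where you really need to rename the variables of a trained TensorFlow model?
- You need to restore weights from a pretrained model, but it was built with a different framework, so some layers' variable names differ even though the network structures correspond one to one, e.g. slim vs. tensorlayer;
- You are converting the model to another framework, e.g. using a tool to convert a TensorFlow model to a Caffe model, and some variable names differ from the BN-layer variable names the conversion tool expects;
- You want to change the length of the variable names; and so on.
So how do we rename them?
First, trained TensorFlow models usually come in two formats, checkpoint and frozen pb. Because the script relies on tf.contrib.framework.list_variables (a TensorFlow 1.x API), only the checkpoint format is supported for now.
Second, we only want to change the variable names and leave the graph structure alone, so the approach is to restore the model and change only the variable names.
Finally, here is the code I experimented with and what it produces when run.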

    #!/usr/bin/env python
    # -*- coding:utf-8 -*-
    '''
    ############################################################
    rename tensorflow variable.
    ############################################################
    '''
    import tensorflow as tf
    import argparse
    import os
    import re


    def get_parser():
        parser = argparse.ArgumentParser(description='parameters to rename tensorflow variable!')
        parser.add_argument('--ckpt_path', type=str, required=True, help='the ckpt file where to load.')
        parser.add_argument('--save_path', type=str, required=True, help='the ckpt file where to save.')
        parser.add_argument('--rename_var_src', type=str, help="""Comma separated list of replace variable from""")
        parser.add_argument('--rename_var_dst', type=str, help="""Comma separated list of replace variable to""")
        parser.add_argument('--add_prefix', type=str, help='prefix of newname.')
        args = parser.parse_args()
        return args


    def load_model(model_path, input_map=None):
        # Only a model directory (containing a metagraph and a checkpoint file)
        # is supported; a protobuf file with a frozen graph is rejected.
        model_exp = os.path.expanduser(model_path)
        if os.path.isfile(model_exp):
            raise ValueError('not support: %s' % model_exp)
        print('Model directory: %s' % model_exp)
        meta_file, ckpt_file = get_model_filenames(model_exp)
        print('Metagraph file: %s' % meta_file)
        print('Checkpoint file: %s' % ckpt_file)
        saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file), input_map=input_map)
        saver.restore(tf.get_default_session(), os.path.join(model_exp, ckpt_file))
        return saver


    def get_model_filenames(model_dir):
        files = os.listdir(model_dir)
        meta_files = [s for s in files if s.endswith('.meta')]
        if len(meta_files) == 0:
            raise ValueError('No meta file found in the model directory (%s)' % model_dir)
        elif len(meta_files) > 1:
            raise ValueError('There should not be more than one meta file in the model directory (%s)' % model_dir)
        meta_file = meta_files[0]
        # Prefer the checkpoint recorded in the "checkpoint" state file.
        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_file = os.path.basename(ckpt.model_checkpoint_path)
            return meta_file, ckpt_file
        # Otherwise fall back to the file name with the largest step number.
        ckpt_file = None
        max_step = -1
        for f in files:
            step_str = re.match(r'(^model-[\w\- ]+\.ckpt-(\d+))', f)
            if step_str is not None and len(step_str.groups()) >= 2:
                step = int(step_str.groups()[1])
                if step > max_step:
                    max_step = step
                    ckpt_file = step_str.groups()[0]
        if ckpt_file is None:
            raise ValueError('No checkpoint file found in the model directory (%s)' % model_dir)
        return meta_file, ckpt_file


    def rename(args):
        '''rename tensorflow variable, just for checkpoint file format.'''
        replace_from = (args.rename_var_src or '').strip().split(',')
        replace_to = (args.rename_var_dst or '').strip().split(',')
        assert len(replace_from) == len(replace_to)
        with tf.Session() as sess:
            for var_name, _ in tf.contrib.framework.list_variables(args.ckpt_path):
                # Load the variable
                var = tf.contrib.framework.load_variable(args.ckpt_path, var_name)
                # Set the new name
                new_name = var_name
                for index in range(len(replace_from)):
                    new_name = new_name.replace(replace_from[index], replace_to[index])
                if args.add_prefix:
                    new_name = args.add_prefix + new_name
                print('Renaming %s to %s.' % (var_name, new_name))
                # Rename the variable
                var = tf.Variable(var, name=new_name)
            # Save the variables. A fresh Saver covers the renamed variables
            # created above; the saver returned by load_model() is built from
            # the imported meta graph and is tied to the original variable
            # names, so it is not used here.
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, args.save_path)


    if __name__ == '__main__':
        args = get_parser()
        rename(args)

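As a side note, the fallback branch of get_model_filenames picks the checkpoint with the largest step number by matching file names against a regular expression. The sketch below isolates just that selection so it can be tried without TensorFlow. `latest_checkpoint_name` is a hypothetical helper and the file names in the demo are made up. Two deliberate choices: the dot before `ckpt` is escaped so it only matches a literal dot, and the function returns None when nothing matches.

```python
import re

# Pattern taken from get_model_filenames, with the dot before "ckpt" escaped.
CKPT_PATTERN = re.compile(r'(^model-[\w\- ]+\.ckpt-(\d+))')


def latest_checkpoint_name(files):
    """Return the checkpoint prefix with the highest step, or None."""
    best_step, best_name = -1, None
    for f in files:
        m = CKPT_PATTERN.match(f)
        if m is None:
            continue
        step = int(m.group(2))  # compare steps numerically, not as text
        if step > best_step:
            best_step, best_name = step, m.group(1)
    return best_name


# Hypothetical directory listing, for illustration only.
files = [
    'model-20180408-102900.ckpt-90.index',
    'model-20180408-102900.ckpt-1200.index',
    'model-20180408-102900.ckpt-1200.data-00000-of-00001',
    'model-20180408-102900.meta',
    'checkpoint',
]
print(latest_checkpoint_name(files))
```

Comparing the steps as integers matters here: as strings, "90" would sort after "1200".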
Save the renaming script as rename_tf_variable.py, then run it with a command line like the following:
python rename_tf_variable.py --ckpt_path ~/github/MobileFaceNet_local/output/ckpt/ --save_path /home/xsr-ai/scripts/ckpt --rename_var_src gamma,moving_mean,moving_variance --rename_var_dst scale,mean,variance
This command renames the batch norm variables "gamma, moving_mean, moving_variance" to "scale, mean, variance". The result looks like this:
![](https://i-blog.csdnimg.cn/blog_migrate/295488ad034f3aafc6233941612f2984.jpeg)
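Before rewriting a real checkpoint, it can help to preview what the name mapping will do. The sketch below reproduces only the string logic of the script (ordered substring replacement, then an optional prefix) as a pure function, so it runs without TensorFlow. `rename_variable` is a hypothetical helper name, and the variable names in the demo are invented for illustration.

```python
def rename_variable(name, replace_from, replace_to, add_prefix=None):
    """Apply the substring replacements in order, then prepend the prefix."""
    if len(replace_from) != len(replace_to):
        raise ValueError('replace_from and replace_to must have the same length')
    new_name = name
    for src, dst in zip(replace_from, replace_to):
        new_name = new_name.replace(src, dst)
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


# Same mapping as in the command line above.
src = 'gamma,moving_mean,moving_variance'.split(',')
dst = 'scale,mean,variance'.split(',')

# Hypothetical variable names, for illustration only.
names = [
    'net/bn1/gamma',
    'net/bn1/beta',
    'net/bn1/moving_mean',
    'net/bn1/moving_variance',
]
for n in names:
    print('%s -> %s' % (n, rename_variable(n, src, dst)))
```

Because the match is a plain substring replacement, any variable whose name merely contains one of the source strings is rewritten too (a name such as `gamma_correction/w` would become `scale_correction/w`), so it is worth checking the preview for unintended hits.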