Tensorflow 模型文件的结构、模型中Tensor查看

初次涉及Tensorflow 模型,转载自https://blog.csdn.net/c20081052/article/details/82961988

深度学习网络训练后保存的模型主要包含两部分,一是网络结构的定义(网络图),二是网络结构里的参数值。Tensorflow作为深度学习框架的一种,未能免俗。

save后留存的文件格式

.meta文件

.meta 文件以 “protocol buffer”格式保存了整个模型的结构图,模型上定义的操作等信息,这个文件保存了网络结构的定义。

按大小看:model.ckpt-3072.meta ,大小是 2.9 MB。

.data-00000-of-00001 文件和 .index 文件

.data-00000-of-00001 文件和 .index 文件合在一起组成了 ckpt 文件,保存了网络结构中所有 权重和偏置 的数值。.data文件保存的是变量值,.index文件保存的是.data文件中数据和 .meta文件中结构图之间的对应关系。

按大小看:model.ckpt-3072.data-00000-of-00001,大小是 3.7 MB ; model.ckpt-3072.index ,大小是 15.5 KB。

checkpoint文件

checkpoint是一个文本文件,记录了训练过程中在所有中间节点上保存的模型的名称,首行记录的是最后(最近)一次保存的模型名称。

按大小看:checkpoint ,大小是 271字节。

查看 ckpt 模型文件中保存的 Tensor信息

查询变量名称和值:

################ 
# This code used to check msg of Tensor stored in ckpt
# work well with tensorflow version of 'v1.3.0-rc2-20-g0787eee'
################
 
import os
from tensorflow.python import pywrap_tensorflow
 
# code for finall ckpt
# checkpoint_path = os.path.join('~/tensorflowTraining/ResNet/model', "model.ckpt")
 
# code for designated ckpt, change 3890 to your num
checkpoint_path = os.path.join('~/tensorflowTraining/ResNet/model', "model.ckpt-3890")
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and values
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))

部分输出信息截图,卷积层:
因为这个部分个人并没有测试,直接拷贝的图片,因此在此不做详细的展开与分析。

查看TensorFlow中checkpoint内变量的几种方法:

查看ckpt中变量的方法有三种:

在有model的情况下,使用tf.train.Saver进行restore
使用tf.train.NewCheckpointReader直接读取ckpt文件,这种方法不需要model。
使用tools里的freeze_graph来读取ckpt

注意:
如果模型保存为.ckpt的文件,则使用该文件就可以查看.ckpt文件里的变量。ckpt路径为 model.ckpt
如果模型保存为.ckpt-xxx-data (图结构)、.ckpt-xxx.index (参数名)、.ckpt-xxx-meta (参数值)文件,则需要同时拥有这三个文件才行。并且ckpt的路径为 model.ckpt-xxx

基于model来读取ckpt文件里的变量

1.首先建立model
2.从ckpt中恢复变量

with tf.Graph().as_default() as g: 
  #建立model
  images, labels = cifar10.inputs(eval_data=eval_data) 
  logits = cifar10.inference(images) 
  top_k_op = tf.nn.in_top_k(logits, labels, 1) 
  #从ckpt中恢复变量
  sess = tf.Session()
  saver = tf.train.Saver() #saver = tf.train.Saver(...variables...) # 恢复部分变量时,只需要在Saver里指定要恢复的变量
  save_path = 'ckpt的路径'
  saver.restore(sess, save_path) # 从ckpt中恢复变量

注意:基于model来读取ckpt中变量时,model和ckpt必须匹配。

使用tf.train.NewCheckpointReader直接读取ckpt文件里的变量
在此基础之上,使用tools.inspect_checkpoint里的print_tensors_in_checkpoint_file函数打印ckpt里的东西。

#使用NewCheckpointReader来读取ckpt里的变量
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) #tf.train.NewCheckpointReader
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
  print("tensor_name: ", key)
  #print(reader.get_tensor(key))
#使用print_tensors_in_checkpoint_file打印ckpt里的内容
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
 
print_tensors_in_checkpoint_file(file_name, #ckpt文件名字
                 tensor_name, # 如果为None,则默认为ckpt里的所有变量
                 all_tensors, # bool 是否打印所有的tensor,这里打印出的是tensor的值,一般不推荐这里设置为False
                 all_tensor_names) # bool 是否打印所有的tensor的name

使用tools里的freeze_graph来读取ckpt

from tensorflow.python.tools import freeze_graph
 
freeze_graph(input_graph, #=some_graph_def.pb
       input_saver, 
       input_binary, 
       input_checkpoint, #=model.ckpt
       output_node_names, #=softmax
       restore_op_name, 
       filename_tensor_name, 
       output_graph, #='./tmp/frozen_graph.pb'
       clear_devices, 
       initializer_nodes, 
       variable_names_whitelist='', 
       variable_names_blacklist='', 
       input_meta_graph=None, 
       input_saved_model_dir=None, 
       saved_model_tags='serve', 
       checkpoint_version=2)
#freeze_graph_test.py讲述了怎么使用freeze_grapg。
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值