TensorFlow 模型保存/加载方法

一.保存模型
tf.train.Saver()类,.save(sess, ckpt文件目录)方法 
     

参数名称 功能说明默认值
var_listSaver中存储变量集合全局变量集合
reshape加载时是否恢复变量形状True
sharded 是否将变量轮循放在所有设备上True
max_to_keep保留最近检查点个数5
restore_sequentially 是否按顺序恢复变量,模型较大时顺序恢复内存消耗小True

 

当var_list是字典形式{变量名字符串: 变量符号}时,相对应的restore也根据同样形式的字典将ckpt中的字符串对应的变量加载给程序中的符号, 如:saver = tf.train.Saver({"v/ExponentialMovingAverage":v}) 


如果Saver给定了字典作为加载方式,则按照字典来,如:saver = tf.train.Saver({"v/ExponentialMovingAverage":v}),否则每个变量寻找自己的name属性在ckpt中的对应值进行加载。

eg:


 
 
  1. #保存代码
  2. saver = tf.train.Saver(max_to_keep= 2)
  3. with tf.Session() as sess:
  4. ...
  5. saver.save(sess, '../model/model.ckpt')

output: 

checkpoint文件会记录保存信息,通过它可以定位最新保存的模型:
.meta文件保存了当前图结构
.index文件保存了当前参数名
.data文件保存了当前参数值

二.加载模型

当我们基于checkpoint文件(ckpt)加载参数时,实际上我们使用Saver.restore取代了initializer的初始化

1. 加载 图结构+模型参数


 
 
  1. ckpt = tf.train.get_checkpoint_state( './model/') #./model为数据加载路径
  2. saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
  3. with tf.Session() as sess:
  4. if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
  5. saver.restore(sess,ckpt.model_checkpoint_path)

2.只加载数据,不加载图结构


 
 
  1. # 只加载数据,不加载图结构,可以在新图中改变batch_size等的值
  2. # 不过需要注意,Saver对象实例化之前需要定义好新的图结构,否则会报错
  3. saver = tf.train.Saver()
  4. ckpt = tf.train.get_checkpoint_state( './model/')
  5. with tf.Session() as sess:
  6. ... #graph 定义
  7. if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
  8. saver.restore(sess,ckpt.model_checkpoint_path)

3.函数说明: 

(1)tf.train.get_checkpoint_state()

参数名称 功能说明
checkpoint_dirchechpoint文件的路径
latest_filename指定chechpoint的名字,默认是'checkpoint'
ckpt = tf.train.get_checkpoint_state('./model/')
 
 

ckpt =tf.train.get_checkpoint_state()  通过  'checkpoint文件'  找到模型文件名
ckpt 包含的属性:
    model_checkpoint_path
        保存了'./model'中最新的tensorflow模型文件的文件名
    all_model_checkpoint_paths:      
        保存了'./model'中所有tensorflow模型文件的文件名

eg:


 
 
  1. ckpt = tf.train.get_checkpoint_state( './model/')
  2. print(ckpt.model_checkpoint_path)

output: 


 

(2)tf.train.import_meta_graph()


 
 
  1. ckpt = tf.train.get_checkpoint_state( './model/') #./model为数据加载路径
  2. saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')

tf.train.import_meta_graph()根据 'model.ckpt-n.meta'加载图结构,并返回saver对象

 

(3) tf.train.Saver.restore()


 
 
  1. ckpt = tf.train.get_checkpoint_state( './model/')
  2. saver.restore(sess,ckpt.model_checkpoint_path)
  3. #等价
  4. saver.restore(sess, './model/model.ckpt-0')
  5. #
  6. new_saver.restore(sess, tf.train.latest_checkpoint( './model/'))

saver.restore()回根据 'model.ckpt-n' 自动寻找参数名--值文件进行加载
 

三. 二进制模型加载

这种加载方法一般是对应网上各大公司已经训练好的网络模型进行修改的工作


 
 
  1. # 新建空白图
  2. self.graph = tf.Graph()
  3. # 空白图列为默认图
  4. with self.graph.as_default():
  5. # 二进制读取模型文件
  6. with tf.gfile.FastGFile(os.path.join(model_dir,model_name), 'rb') as f:
  7. # 新建GraphDef文件,用于临时载入模型中的图
  8. graph_def = tf.GraphDef()
  9. # GraphDef加载模型中的图
  10. graph_def.ParseFromString(f.read())
  11. # 在空白图中加载GraphDef中的图
  12. tf.import_graph_def(graph_def,name= '')
  13. # 在图中获取张量需要使用graph.get_tensor_by_name加张量名
  14. # 这里的张量可以直接用于session的run方法求值了
  15. # 补充一个基础知识,形如'conv1'是节点名称,而'conv1:0'是张量名称,表示节点的第一个输出张量
  16. self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
  17. self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name in self.layer_operation_names]

 

四.二进制模型制作

这节是关于tensorflow的Freezing,字面意思是冷冻,可理解为整合合并;整合什么呢,就是将模型文件和权重文件整合合并为一个文件,主要用途是便于发布。

官方解释可参考:https://www.tensorflow.org/extend/tool_developers/#freezing

tensorflow在训练过程中,通常不会将权重数据保存的格式文件里(这里我理解是模型文件),反而是分开保存在一个叫checkpoint的检查点文件里,当初始化时,再通过模型文件里的变量Op节点来从checkoupoint文件读取数据并初始化变量。这种模型和权重数据分开保存的情况,使得发布产品时不是那么方便,所以便有了freeze_graph.py脚本文件用来将这两文件整合合并成一个文件。

freeze_graph.py是怎么做的呢?
1.加载模型文件,
2.从checkpoint文件读取权重数据初始化到模型里的权重变量,再将权重变量转换成权重 常量 (因为 常量 能随模型一起保存在同一个文件里)
3.再通过指定的输出节点将没用于输出推理的Op节点从图中剥离掉,再重新保存到指定的文件里(用write_graphdef或Saver)

文件目录:tensorflow/python/tools/free_graph.py
测试文件:tensorflow/python/tools/free_graph_test.py 这个测试文件很有学习价值

参数:
input_graph:(必选)
模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分(见下面说明)

input_saver:(可选)
Saver解析器。保存模型和权限时,Saver也可以自身序列化保存,以便在加载时应用合适的版本。主要用于版本不兼容时使用。可以为空,为空时用当前版本的Saver。

input_binary:(可选)
配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认False

input_checkpoint:(必选)
检查点数据文件。训练时,给Saver用于保存权重、偏置等变量值。这时用于模型恢复变量值。

output_node_names:(必选)
输出节点的名字,有多个时用逗号分开。用于指定输出节点,将没有在输出线上的其它节点剔除。

restore_op_name:(可选)
从模型恢复节点的名字。升级版中已弃用。默认:save/restore_all

filename_tensor_name:(可选)
已弃用。默认:save/Const:0

output_graph:(必选)
用来保存整合后的模型输出文件。

clear_devices:(可选)默认True。
指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认)

initializer_nodes:(可选)默认空。
权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。

11、variable_names_blacklist:(可先)默认空。
变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。


 
 
  1. python tensorflow/python/tools/free_graph.py \
  2. --input_graph=some_graph_def.pb \ 注意:这里的pb文件是用tf.train.write_graph方法保存的
  3. --input_checkpoint=model.ckpt .1001 \ 注意:这里若是r12以上的版本,只需给.data -00000....前面的文件名,如:model.ckpt .1001.data -00000-of -00001,只需写model.ckpt .1001
  4. --output_graph=/tmp/frozen_graph.pb
  5. --output_node_names=softmax

另外,如果模型文件是.meta格式的,也就是说用saver.Save方法和checkpoint一起生成的元模型文件,free_graph.py不适用,但可以改造下:
1、copy free_graph.py为free_graph_meta.py
2、修改free_graph.py,导入meta_graph:from tensorflow.python.framework import meta_graph
3、将91行到97行换成:input_graph_def = meta_graph.read_meta_graph_file(input_graph).graph_def

 

参考:

https://www.cnblogs.com/hellcat/p/6925757.html

https://blog.csdn.net/changeforeve/article/details/80268522 

https://blog.csdn.net/mrr1ght/article/details/81023330

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值