tensorflow保存模型的两种方法

由于神经网络训练比较复杂所以可能需要先保存训练好的模型,然后再需要的时候进行调用,下面介绍两种保存模型的方法:

方法一:使用tf.train.Saver()

保存代码,该方法保存的模型比较全,只要定义的变量均可获取,导入的模型与当前生成几乎具有一样的能力:

#定义占位符,具有名称的变量可以被在导入模型后获取
x = tf.placeholder(tf.float32, [None, 784], name ='x')
y_ = tf.placeholder(tf.int64, [None], name='y_')

#有些变量名难以定义,可以通过下面的方法保存
tf.add_to_collection('pred_network', y_conv)
tf.add_to_collection('pred_network', keep_prob)

#保存模型,目录model,前缀mnist_model
saver = tf.train.Saver()
with tf.Session() as sess:
    #训练部分,省略
	saver.save(sess, './model/mist_model')

加载:

with tf.Session() as sess:
    model = tf.train.import_meta_graph('./model/mist_model.meta')
    model.restore(sess, './model/mist_model')
    
    #加载变量,注意变量名必须是定义过的:
    #这部分因为是将变量存入一个集合中,所以需要注意顺序
	y_conv = tf.get_collection('pred_network')[0]
    keep_prob = tf.get_collection('pred_network')[1]
    
    graph = tf.get_default_graph()
    x = graph.get_operation_by_name('x').outputs[0]
    y_ = graph.get_operation_by_name('y_').outputs[0] 

方法二:冻结模型

保存代码,该方法只能保存一种类型的网络,函数根据定义的输出节点来确定网络的类型是评估或者分类,也就是根据输出节点,往前倒推,有关联的变量才存储。另外,不同的图之间不能运算,也就是说,加载的图中变量不能用于新的计算。
:如果当前网络为分类网络,那么,即使之前训练的网络中包含用于评估的标签变量,该分类网络也不能导入该标签变量。
冻结模型也有两种方法:

方法一:
from tensorflow.python.framework import graph_util
#变量名定义
x = tf.placeholder(tf.float32, [None, 784], name ='x')

with tf.Session() as sess:
    #训练部分,省略
	constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['classifier/predicition'])
     with tf.gfile.FastGFile("./model/outModel1.pb", mode='wb') as f:
        f.write(constant_graph.SerializeToString())
方法二:freeze_graph.freeze_graph()

该方法应该主要以命令行方式将存储的一系列模型文件转换成pb格式模型

from tensorflow.python.tools import freeze_graph
from tensorflow.python.framework import graph_io

# 分别保存图和变量
saver = tf.train.Saver()
checkpoint_path = saver.save(sess, "./model/mist_model")
# 以文本方式存储所有节点信息
graph_io.write_graph(sess.graph, "./model/", "model.pb")

# 定义冻结图方法的参数
input_graph_path = os.path.join("./model/", "model.pb")
input_saver_def_path = ""
input_binary = False
output_node_names = "classifier/predicition"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join("./model/", "outModel2.pb")
clear_devices = False
input_meta_graph = "./model/mist_model.meta"
# 冻结图
freeze_graph.freeze_graph(
   input_graph_path,
   input_saver_def_path,
   input_binary,
   checkpoint_path,
   output_node_names,
   restore_op_name,
   filename_tensor_name,
   output_graph_path,
   clear_devices,
   "",
   "",
   input_meta_graph,
   checkpoint_version=1)

加载:

graph = tf.Graph()
graph_def = tf.GraphDef()
with open('./model/outModel1.pb', "rb") as f:
	graph_def.ParseFromString(f.read())
with graph.as_default():
	tf.import_graph_def(graph_def)

#变量加载,只能加载与本网络功能有关的变量。
#注意:变量名前必须包含import,否则,报错:“The name 'x' refers to an Operation not in the graph”
x = graph.get_operation_by_name('import/x').outputs[0]
keep_prob = graph.get_operation_by_name('import/dropout/keep_prob').outputs[0] 
pred = graph.get_operation_by_name('import/classifier/predicition').outputs[0]
with tf.Session(graph=graph) as sess:
 	#进行计算
  • 3
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值