Tensorflow学习笔记(六)模型的合并

将多个模型文件合并成一个模型文件

声明: 本篇博文主要是参考这篇博文做的一些测试和改进!更多详细的细节可以参考原文。

先用之前Tensorflow学习笔记(二)模型的保存与加载(一 )中的代码生成SavedModel模型文件,如
在这里插入图片描述
这里的模型效果是输入一个x,返回x+2

定义一个简单的模型

with tf.Graph().as_default() as g_one:
	input1 = tf.placeholder(tf.float32,name='one_input')
	data = tf.Variable(3.)
	mul = tf.multiply(input1,data)
	tf.identity(mul,name='one_output')
	init = tf.global_variables_initializer()
	with tf.Session(graph=g_one) as sess:
		sess.run(init)
		g1def = graph_util.convert_variables_to_constants(
			sess,
			sess.graph_def,
			["one_output"],
			variable_names_whitelist=None,
			variable_names_blacklist=None)

这里的模型效果是,把输入x乘了一个3.0并输出,使用graph_util.convert_variables_to_constants将模型中的变量转化为常量。

加载SavedModel模型

with tf.Graph().as_default() as g_two:
	with tf.Session(graph=g_two) as sess:
		# input_graph_def = saved_model_utils.get_meta_graph_def(
		# 	"./models", tf.saved_model.tag_constants.SERVING).graph_def
		tf.saved_model.loader.load(sess, ["serve"], "./models")
		g2def = graph_util.convert_variables_to_constants(
			sess,
			sess.graph_def,
			["output"],
			variable_names_whitelist=None,
			variable_names_blacklist=None)

加载.meta模型模型

先用之前Tensorflow学习笔记(三)模型的保存与加载(二)中的代码生成SavedModel模型文件,如
在这里插入图片描述

with tf.Graph().as_default() as g_two:
	ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
	with tf.Session(graph=g_two) as sess:
		saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
		saver.restore(sess, ckpt.model_checkpoint_path)
		g2def = graph_util.convert_variables_to_constants(
			sess,
			sess.graph_def,
			["output"])

连接两个模型

这里把第一个模型输出的结果当作第二个模型的输入

with tf.Graph().as_default() as g_combined:
	with tf.Session(graph=g_combined) as sess:
		x = tf.placeholder(tf.float32, name="my_input")
		y = tf.import_graph_def(g1def, input_map={"one_input:0": x}, return_elements=["one_output:0"])
		z, = tf.import_graph_def(g2def, input_map={"input:0": y}, return_elements=["output:0"])
		tf.identity(z, "my_output")
		print(sess.run(z,feed_dict={'my_input:0':3.}))

这里需要注意的是z, = tf.import_graph_def(g2def, input_map={"input:0": y}, return_elements=["output:0"]) z , 不是 z 因为这里的输出结果是个列表
在这里插入图片描述
运行结果:

在这里插入图片描述
也就是3*3+2

保存成为新的模型

这里可以用之前几篇提到的保存方法来保存新的模型,不过有些细节需要注意

保存成一个.pb模型
		g_combineddef = graph_util.convert_variables_to_constants(sess,sess.graph_def,["my_output"])
		MODEL_SAVE_PATH = "./models/"  # 保存模型的路径
		tf.train.write_graph(g_combineddef, MODEL_SAVE_PATH, 'my_model.pb', as_text=False)

这里注意缩进,在with tf.Session(graph=g_combined) as sess:的包含内,因为with 特性在结束后会关闭Session,而这里的convert_variables_to_constants用到了sess。
在这里插入图片描述
这里生成的my_model.pb可以像之前博客中一样被安卓调用。

保存为SavedModel模型
tf.saved_model.simple_save(sess, "./modelbase",inputs={"my_input": x},
					 outputs={"my_output": z})

这里是保存SavedModel模型最简单的方法,当然也可以用 之前博客中使用的标准方法。
在这里插入图片描述
这里生成的SavedModel模型可以用上一篇讲到的合并成一个.pb文件在被Android端调用,调用方法跟上面一样。
这里需要注意的地方是保存的地址文件夹,不能提前存在或者说重复创建否则就会像这样报错
在这里插入图片描述

保存为.meta模型

其实保存方法跟之前提到的是一样的,只不过这里因为变量Variable都被转化为常量constant 所以不能保存为.meta模型了!当然主要是我没有想到,如果有会的可以在评论区给我留言,一起学习交流!

完整代码

简单的模型与SavedModel模型
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.tools import saved_model_utils
MODEL_SAVE_PATH = "./models/" # 保存模型的路径
with tf.Graph().as_default() as g_one:
	input1 = tf.placeholder(tf.float32,name='one_input')
	data = tf.Variable(3.)
	mul = tf.multiply(input1,data)
	tf.identity(mul,name='one_output')
	init = tf.global_variables_initializer()
	with tf.Session(graph=g_one) as sess:
		sess.run(init)
		g1def = graph_util.convert_variables_to_constants(
			sess,
			sess.graph_def,
			["one_output"],
			variable_names_whitelist=None,
			variable_names_blacklist=None)

with tf.Graph().as_default() as g_two:
	ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
	with tf.Session(graph=g_two) as sess:
		saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
		saver.restore(sess, ckpt.model_checkpoint_path)
		g2def = graph_util.convert_variables_to_constants(
			sess,
			sess.graph_def,
			["output"])

with tf.Graph().as_default() as g_combined:
	with tf.Session(graph=g_combined) as sess:
		x = tf.placeholder(tf.float32, name="my_input")
		y = tf.import_graph_def(g1def, input_map={"one_input:0": x}, return_elements=["one_output:0"])
		z, = tf.import_graph_def(g2def, input_map={"input:0": y}, return_elements=["output:0"])
		tf.identity(z, "my_output")
		print(sess.run(z,feed_dict={'my_input:0':3.}))


		# 保存1
		g_combineddef = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["my_output"])
		tf.train.write_graph(g_combineddef, MODEL_SAVE_PATH, 'my_model.pb', as_text=False)

		# 保存2
		# tf.saved_model.simple_save(sess,
		# 						   "./modelbase",
		# 						   inputs={"my_input": x},
		# 						   outputs={"my_output": z})
简单的模型与.meta模型模型
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.tools import saved_model_utils
MODEL_SAVE_PATH = "./models/" # 保存模型的路径
with tf.Graph().as_default() as g_one:
	input1 = tf.placeholder(tf.float32,name='one_input')
	data = tf.Variable(3.)
	mul = tf.multiply(input1,data)
	tf.identity(mul,name='one_output')
	init = tf.global_variables_initializer()
	with tf.Session(graph=g_one) as sess:
		sess.run(init)
		g1def = graph_util.convert_variables_to_constants(
			sess,
			sess.graph_def,
			["one_output"],
			variable_names_whitelist=None,
			variable_names_blacklist=None)

with tf.Graph().as_default() as g_two:
	with tf.Session(graph=g_two) as sess:
		tf.saved_model.loader.load(sess, ["serve"], "./models")
		g2def = graph_util.convert_variables_to_constants(
			sess,
			sess.graph_def,
			["output"],
			variable_names_whitelist=None,
			variable_names_blacklist=None)

with tf.Graph().as_default() as g_combined:
	with tf.Session(graph=g_combined) as sess:
		x = tf.placeholder(tf.float32, name="my_input")
		y = tf.import_graph_def(g1def, input_map={"one_input:0": x}, return_elements=["one_output:0"])
		z, = tf.import_graph_def(g2def, input_map={"input:0": y}, return_elements=["output:0"])
		tf.identity(z, "my_output")
		print(sess.run(z,feed_dict={'my_input:0':3.}))


		# 保存1
		g_combineddef = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["my_output"])
		tf.train.write_graph(g_combineddef, MODEL_SAVE_PATH, 'my_model.pb', as_text=False)

		# 保存2
		# tf.saved_model.simple_save(sess,
		# 						   "./modelbase",
		# 						   inputs={"my_input": x},
		# 						   outputs={"my_output": z})

当然也可以.meta模型模型与SavedModel模型,
.meta模型模型与.meta模型模型,
SavedModel模型与SavedModel模型
这些区别都不大,只是需要注意输入输出的name,这里我就不举例子了,感兴趣的可以自己尝试!

希望这篇文章对您有帮助,感谢阅读!

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值