tensorflow || 1. Protocol Buffer
tensorflow || 2. tensorflow框架实现的Graph总结
tensorflow || 3. graph的相关操作、保存与加载pb文件
1 graph的创建
在tensorflow中,程序会自动创建一个图。也可以手动建立图,并对图进行一些操作。
tf.graph
:创建空图tf.get_default_graph
:获取当前的默认图结构tf.reset_default_graph
:对图进行重置。每使用一次,重置的新图,分配不同的地址。import tensorflow as tf Adata = tf.constant(5) Bgraph = tf.Graph() # 创建一个图Bgraph with Bgraph.as_default(): # 将Bgraph图设为当默认图 Bdata1 = tf.constant(5, name="Bdata1") Bdata2 = tf.constant(10, name="Bdata2") Bdata3 = tf.multiply(Bdata1, Bdata2, name="add") Bgtmp = tf.get_default_graph() #这里获取的图,即G1,as_default()>和get_default_graph()相呼应 Agraph = tf.get_default_graph() #当离开的Bgraph的域,默认图又恢复为最开始的图 tf.reset_default_graph() #reset一次,为不同的图 Gtmp1 = tf.get_default_graph() tf.reset_default_graph() Gtmp2 = tf.get_default_graph() print("Bdata3.graph:", Bdata3.graph) print("Bgraph: ", Bgraph) print("Bgtmp: ", Bgtmp) print("Adata.graph:", Adata.graph) print("Agraph: ", Agraph) print("Gtmp1: ", Gtmp1) print("Gtmp2: ", Gtmp2)
打印结果:Bdata3.graph: <tensorflow.python.framework.ops.Graph object at 0x7fabfa7e30b8> Bgraph: <tensorflow.python.framework.ops.Graph object at 0x7fabfa7e30b8> Bgtmp: <tensorflow.python.framework.ops.Graph object at 0x7fabfa7e30b8> Adata.graph: <tensorflow.python.framework.ops.Graph object at 0x7fac0b0e5f28> Agraph: <tensorflow.python.framework.ops.Graph object at 0x7fac0b0e5f28> Gtmp1: <tensorflow.python.framework.ops.Graph object at 0x7fabfa7f2198> Gtmp2: <tensorflow.python.framework.ops.Graph object at 0x7fabfa7f2400>
根据上面的运行结果可以看出:
- tf.Graph()创建的图Agraph,与全局默认的图是独立的。
- 在Agraph下,使用tf.get_default_graph(),获取的是Agraph;当离开了Agraph,再使用tf.get_default_graph()获取的是全局默认的图。
- tf.reset_default_graph() 每使用一次,重置的新图,分配不同的地址。使用该函数,要保证当前图中的资源都已经全部进行了释放,否则会报错。
p = Predict(return_elements, pb_file)
2 graph的操作
可以在图中通过名字得到其对应的元素,比如获取变量或OP等
import tensorflow as tf Bgraph = tf.Graph() with Bgraph.as_default(): Bdata1 = tf.constant(5, name="Bdata1") Bdata2 = tf.constant(10, name="Bdata2") Bdata3 = tf.multiply(Bdata1, Bdata2, name="multiply") R1 = Bgraph.get_tensor_by_name("multiply:0") R2 = Bgraph.get_operation_by_name("multiply") print("Bdata2:", Bdata2) print("Bdata3:", Bdata3) print(R1) print(R2) with tf.Session(graph=Bgraph) as sess: print(sess.run(R1)) print(sess.run(R2.outputs[0]))
注意区别:
graph.get_operation_by_name("multiply")
:从图中获取到命名为"multiply"的op,这个节点的输出的第0个为Bdata3Bgraph.get_tensor_by_name("multiply:0")
:从图中获取到命名为"multiply:0"的tensor,该tensor为"multiply"节点的第0个输出,与graph.get_operation_by_name("multiply").outputs[0]
同效。
打印结果:Bdata2: Tensor("Bdata2:0", shape=(), dtype=int32) Bdata3: Tensor("multiply:0", shape=(), dtype=int32) R1: Tensor("multiply:0", shape=(), dtype=int32) R2: name: "multiply" op: "Mul" input: "Bdata1" input: "Bdata2" attr { key: "T" value { type: DT_INT32 } } 50 50
3 graph的序列化
保存pb文件的
- 1
–sess.graph.as_graph_def()
:将图进行序列化
–tf.train.write_graph()
:将序列化的图保存。一般不适用这个api,这样保存的模型,只能用于测试。(一般保存为 .ckpt 文件,用于继续训练或测试,直到效果ok再进行保存pb文件 )# g1的图定义,包含pb的path, pb文件名,是否是文本默认False tf.train.write_graph(sess.graph.as_graph_def(), './', 'graph.pb', False)
- 2
–tf.graph_util.convert_variables_to_constants
:该函数,会将计算图中的变量取值以常量的形式保存。在保存模型文件的时候,这里保存了从输入层到输出层的计算过程的Graph_Def。其余所有不涉及的前向传播和所有的反向传播(梯度)都会被舍弃。
在保存的时候,参数output_node_names指定保存的节点名称(不是张量的名称)
–graph_def.SerializeToString()
:将序列化的模型转换为字符串,用于写入文件。output_node_names = [ # "input/input_data", # 输入的节点这里不用提供 "pred_sbbox/concat_2", "pred_mbbox/concat_2", "pred_lbbox/concat_2"] converted_graph_def = tf.graph_util.convert_variables_to_constants(sess, input_graph_def = sess.graph.as_graph_def(), output_node_names = output_node_names) with tf.gfile.GFile("./yolov3_coco.pb", "wb") as f: f.write(converted_graph_def.SerializeToString())
导入pb文件思路:
- 使用
tf.Graph
、tf.GraphDef
分别创建个Graph、GraphDef- 用tensorflow的api
tf.gfile.FastGFile(pb_file, 'rb')
打开pb文件,并使用.ParseFromString(f.read())
导入到GraphDef中- 使用
tf.import_graph_def
将GraphDef中的网络结构和参数导入到Graph中。pb_file = "./yolov3_coco.pb" return_elements = ["input/input_data:0", "pred_sbbox/concat_2:0", >"pred_mbbox/concat_2:0", "pred_lbbox/concat_2:0"] graph = tf.Graph() with tf.gfile.FastGFile(pb_file, 'rb') as f: frozen_graph_def = tf.GraphDef() frozen_graph_def.ParseFromString(f.read()) with Graph.as_default(): return_elements = tf.import_graph_def(frozen_graph_def, return_elements=return_elements) # x = graph.get_tensor_by_name("input/input_data:0") # 这样获取的x与return_elements[0]为相同的东西 #all_ops = [op for op in graph.get_operations()] #cons_ops = [op for op in graph.get_operations() if op.type=='Const'] #print(len(all_ops)) #print(len(cons_ops)) with tf.Session(graph=graph) as sess: pred_sbbox, pred_mbbox, pred_lbbox = sess.run( [return_tensors[1], return_tensors[2], return_tensors[3]], feed_dict={ return_tensors[0]: image_data})
tf.GraphDef()
创建一个序列化的空图tf.Graph().as_default()
创建个图tf.GraphDef().ParseFromString(f.read("pb_file"))
把pb文件导入到创建的序列化空图中tf.import_graph_def()
将序列化图导入到当前图中,并且返回tensor或op或不返回,内容根据入参而定。
实际自己编写读取pb文件进行预测时,为了代码的可观性,会定义成类,方便调用,这里记录下:
class Predict(): def __init__(self,return_elements1, pb_file1): self.return_elements = return_elements1 self.pb_file = pb_file1 self.graph = tf.Graph() self.return_tensors = self.read_pb_return_tensors(self.graph, self.pb_file, self.return_elements) def read_pb_return_tensors(self, graph, pb_file, return_elements): with tf.gfile.FastGFile(pb_file, 'rb') as f: frozen_graph_def = tf.GraphDef() frozen_graph_def.ParseFromString(f.read()) # for i, n in enumerate(frozen_graph_def.node): # print("Name of the node - %s" % n.name) with graph.as_default(): return_elements = tf.import_graph_def(frozen_graph_def, return_elements=return_elements) return return_elements pb_file = "./yolov3_coco.pb" return_elements = ["input/input_data:0", "pred_sbbox/concat_2:0", "pred_mbbox/concat_2:0", "pred_lbbox/concat_2:0"] p = Predict(return_elements, pb_file) with tf.Session(graph=p.graph) as sess: outputD = sess.run(p.return_tensors[1:-1], feed_dict={p.return_tensors[0]: ****})