tensorflow || 3. graph的相关操作、保存与加载pb文件

tensorflow || 1. Protocol Buffer
tensorflow || 2. tensorflow框架实现的Graph总结
tensorflow || 3. graph的相关操作、保存与加载pb文件

1 graph的创建

在tensorflow中,程序会自动创建一个图。也可以手动建立图,并对图进行一些操作。

  • tf.graph:创建空图
  • tf.get_default_graph:获取当前的默认图结构
  • tf.reset_default_graph:对图进行重置。每使用一次,重置的新图,分配不同的地址。
import tensorflow as tf

Adata = tf.constant(5)

Bgraph = tf.Graph()  # 创建一个图Bgraph
with Bgraph.as_default():  # 将Bgraph图设为当默认图
   Bdata1 = tf.constant(5, name="Bdata1")
   Bdata2 = tf.constant(10, name="Bdata2")
   Bdata3 = tf.multiply(Bdata1, Bdata2, name="add")
   Bgtmp = tf.get_default_graph()   #这里获取的图,即G1,as_default()>和get_default_graph()相呼应

Agraph = tf.get_default_graph() #当离开的Bgraph的域,默认图又恢复为最开始的图

tf.reset_default_graph()  #reset一次,为不同的图
Gtmp1 = tf.get_default_graph()
tf.reset_default_graph()
Gtmp2 = tf.get_default_graph()

print("Bdata3.graph:", Bdata3.graph)
print("Bgraph:     ", Bgraph)
print("Bgtmp:      ", Bgtmp)

print("Adata.graph:", Adata.graph)
print("Agraph:     ", Agraph)

print("Gtmp1:       ", Gtmp1)
print("Gtmp2:       ", Gtmp2)


打印结果:

Bdata3.graph: <tensorflow.python.framework.ops.Graph object at 0x7fabfa7e30b8>
Bgraph:      <tensorflow.python.framework.ops.Graph object at 0x7fabfa7e30b8>
Bgtmp:       <tensorflow.python.framework.ops.Graph object at 0x7fabfa7e30b8>
Adata.graph: <tensorflow.python.framework.ops.Graph object at 0x7fac0b0e5f28>
Agraph:      <tensorflow.python.framework.ops.Graph object at 0x7fac0b0e5f28>
Gtmp1:        <tensorflow.python.framework.ops.Graph object at 0x7fabfa7f2198>
Gtmp2:        <tensorflow.python.framework.ops.Graph object at 0x7fabfa7f2400>

根据上面的运行结果可以看出:

  • tf.Graph()创建的图Agraph,与全局默认的图是独立的。
  • 在Agraph下,使用tf.get_default_graph(),获取的是Agraph;当离开了Agraph,再使用tf.get_default_graph()获取的是全局默认的图。
  • tf.reset_default_graph() 每使用一次,重置的新图,分配不同的地址。使用该函数,要保证当前图中的资源都已经全部进行了释放,否则会报错。
p = Predict(return_elements, pb_file)

2 graph的操作

可以在图中通过名字得到其对应的元素,比如获取变量或OP等

import tensorflow as tf

Bgraph = tf.Graph() 
with Bgraph.as_default(): 
   Bdata1 = tf.constant(5, name="Bdata1")
   Bdata2 = tf.constant(10, name="Bdata2")
   Bdata3 = tf.multiply(Bdata1, Bdata2, name="multiply")

R1 = Bgraph.get_tensor_by_name("multiply:0")
R2 = Bgraph.get_operation_by_name("multiply")

print("Bdata2:", Bdata2)
print("Bdata3:", Bdata3)
print(R1)
print(R2)

with tf.Session(graph=Bgraph) as sess:
   print(sess.run(R1))
   print(sess.run(R2.outputs[0]))

注意区别:

  • graph.get_operation_by_name("multiply"):从图中获取到命名为"multiply"的op,这个节点的输出的第0个为Bdata3
  • Bgraph.get_tensor_by_name("multiply:0"):从图中获取到命名为"multiply:0"的tensor,该tensor为"multiply"节点的第0个输出,与 graph.get_operation_by_name("multiply").outputs[0]同效。


打印结果:

Bdata2: Tensor("Bdata2:0", shape=(), dtype=int32)
Bdata3: Tensor("multiply:0", shape=(), dtype=int32)
R1:     Tensor("multiply:0", shape=(), dtype=int32)
R2:     name: "multiply"
op: "Mul"
input: "Bdata1"
input: "Bdata2"
attr {
 key: "T"
 value {
   type: DT_INT32
 }
}

50
50

3 graph的序列化

保存pb文件的

  • 1
    sess.graph.as_graph_def():将图进行序列化
    tf.train.write_graph():将序列化的图保存。一般不适用这个api,这样保存的模型,只能用于测试。(一般保存为 .ckpt 文件,用于继续训练或测试,直到效果ok再进行保存pb文件 )
# g1的图定义,包含pb的path, pb文件名,是否是文本默认False  
tf.train.write_graph(sess.graph.as_graph_def(), './', 'graph.pb', False)  

  • 2
    tf.graph_util.convert_variables_to_constants:该函数,会将计算图中的变量取值以常量的形式保存。在保存模型文件的时候,这里保存了从输入层到输出层的计算过程的Graph_Def。其余所有不涉及的前向传播和所有的反向传播(梯度)都会被舍弃。
    在保存的时候,参数output_node_names指定保存的节点名称(不是张量的名称)
    graph_def.SerializeToString():将序列化的模型转换为字符串,用于写入文件。
output_node_names = [
			# "input/input_data",  # 输入的节点这里不用提供
			"pred_sbbox/concat_2", "pred_mbbox/concat_2", "pred_lbbox/concat_2"]
converted_graph_def = tf.graph_util.convert_variables_to_constants(sess,
                           input_graph_def  = sess.graph.as_graph_def(),
                           output_node_names = output_node_names)

with tf.gfile.GFile("./yolov3_coco.pb", "wb") as f:
   f.write(converted_graph_def.SerializeToString())

导入pb文件思路

  • 使用tf.Graphtf.GraphDef分别创建个Graph、GraphDef
  • 用tensorflow的api tf.gfile.FastGFile(pb_file, 'rb')打开pb文件,并使用.ParseFromString(f.read())导入到GraphDef中
  • 使用tf.import_graph_def将GraphDef中的网络结构和参数导入到Graph中。
pb_file         = "./yolov3_coco.pb"
return_elements = ["input/input_data:0", "pred_sbbox/concat_2:0", >"pred_mbbox/concat_2:0", "pred_lbbox/concat_2:0"]
graph           = tf.Graph()

with tf.gfile.FastGFile(pb_file, 'rb') as f:
   frozen_graph_def = tf.GraphDef()
   frozen_graph_def.ParseFromString(f.read())

with Graph.as_default():
   return_elements = tf.import_graph_def(frozen_graph_def, return_elements=return_elements)
                                        
# x = graph.get_tensor_by_name("input/input_data:0")  # 这样获取的x与return_elements[0]为相同的东西
                                 
#all_ops = [op for op in graph.get_operations()]
#cons_ops = [op for op in graph.get_operations() if op.type=='Const']
#print(len(all_ops))
#print(len(cons_ops))

with tf.Session(graph=graph) as sess:
   pred_sbbox, pred_mbbox, pred_lbbox = sess.run( [return_tensors[1], return_tensors[2], return_tensors[3]],
                                                  feed_dict={ return_tensors[0]: image_data})
  • tf.GraphDef() 创建一个序列化的空图
  • tf.Graph().as_default() 创建个图
  • tf.GraphDef().ParseFromString(f.read("pb_file")) 把pb文件导入到创建的序列化空图中
  • tf.import_graph_def() 将序列化图导入到当前图中,并且返回tensor或op或不返回,内容根据入参而定。


实际自己编写读取pb文件进行预测时,为了代码的可观性,会定义成类,方便调用,这里记录下:

class Predict():

   def __init__(self,return_elements1,  pb_file1):
       self.return_elements = return_elements1
       self.pb_file = pb_file1
       self.graph = tf.Graph()
       self.return_tensors = self.read_pb_return_tensors(self.graph, self.pb_file, self.return_elements)

   def read_pb_return_tensors(self, graph, pb_file, return_elements):
       with tf.gfile.FastGFile(pb_file, 'rb') as f:
           frozen_graph_def = tf.GraphDef()
           frozen_graph_def.ParseFromString(f.read())
           # for i, n in enumerate(frozen_graph_def.node):
           #     print("Name of the node - %s" % n.name)

       with graph.as_default():
           return_elements = tf.import_graph_def(frozen_graph_def, return_elements=return_elements)
       return return_elements
       
pb_file         = "./yolov3_coco.pb"
return_elements = ["input/input_data:0", "pred_sbbox/concat_2:0", "pred_mbbox/concat_2:0", "pred_lbbox/concat_2:0"]
p = Predict(return_elements, pb_file)
with tf.Session(graph=p.graph) as sess:
    outputD = sess.run(p.return_tensors[1:-1], feed_dict={p.return_tensors[0]: ****})
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值