tensorflow || 2. tensorflow框架实现的Graph总结

tensorflow || 1. Protocol Buffer
tensorflow || 2. tensorflow框架实现的Graph总结
tensorflow || 3. graph的相关操作、保存与加载pb文件

1. Graph

在tensorflow官方文档中,Graph被定义为“一些Operation和Tensor的集合”。我们使用的python代码表达的计算,就会生成一张图。并在tensorboard中可视化

import tensorflow as tf

Bgraph = tf.Graph()  # 创建一个图Bgraph
with Bgraph.as_default():  # 将Bgraph图设为当默认图
    Bdata1 = tf.placeholder(tf.float32, name="Bdata1")
    Bdata2 = tf.placeholder(tf.float32, name="Bdata2")
    Bdata3 = tf.multiply(Bdata1, Bdata2,  name="multiply")
    tf.summary.FileWriter("./test", Bgraph)

在这里插入图片描述
上图中,每一个圆圈表示一个Op,椭圆到椭圆的边为tensor,箭头的指向表示了这张图的Op的输入输出的tensor的传递关系。
.

  • 构建Python Graph】Python代码中所描述的 Graph(简称Python Graph),包含各个op和tensor。
  • tensorflow运行时图的序列化】tensorflow先将Python Graph进行转换为Protocol Buffer(即序列化,序列化后的图称为GraphDef)。
  • session启动】通过C/C++/CUDA运行GraphDef,真实的数据计算会被放在多CPU、GPU、ARM等完成,并不是始终不变的东西。

至此,引出了GraphDef的存在。

2 GraphDef

将Python Graph进行序列化,得到的图称作GraphDef。
GraphDef由许多的叫做NodeDef的Protocol Buffer组成,NodeDef与Python Graph中的Operation相对应(也就是 tf.Operation的序列化ProtoBuf 是 NodeDef)。
GraphDef 的保存与加载的api为 tf.train.write_graph()/ tf.Import_graph_def()

import tensorflow as tf

Bgraph = tf.Graph()  # 创建一个图Bgraph
with Bgraph.as_default():  # 将Bgraph图设为当默认图
    Bdata1 = tf.placeholder(tf.float32, name="Bdata1")
    Bdata2 = tf.placeholder(tf.float32, name="Bdata2")
    Bdata3 = tf.multiply(Bdata1, Bdata2, name="multiply")

with tf.Session(graph=Bgraph) as sess:
    tf.train.write_graph(sess.graph_def, "./", "test.pb", False)
    print(sess.graph_def)

代码中的GraphDef中的内容:3个NodeDef。NodeDef中,包含name, op, input, attr。其中input是不同的node之间的连接信息。
GraphDef中不会保存Variable的信息,但可以保存Constant。
所以是将Python Graph 中的Variable替换为constant 存储在GraphDef中(训练好的weight(variable)得以保存)。所以从graphdef来恢复的图和权重,没有Variable,只能用于实际上的inference,无法用于训练。


另外,我们有时候会想以网络的中间节点 作为输出节点来保存pb模型。
其中核心api为tf.graph_util.convert_variables_to_constants()converted_graph_def.SerializeToString()

import tensorflow as tf

input_a = tf.placeholder(tf.float32, (1, 224, 224,3), "l1")
input_b = tf.placeholder(tf.float32, (1, 224, 224,3), "l2")
features = tf.subtract(input_a,input_b)

print(features)
with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())

   output_node_names = ["Sub"]  # 输出节点
   converted_graph_def = tf.graph_util.convert_variables_to_constants(sess,
							                              input_graph_def  = sess.graph.as_graph_def(),
                     							         output_node_names = output_node_names)
   with tf.gfile.GFile("./pb_file_test.pb", "wb") as f:
      f.write(converted_graph_def.SerializeToString())

f_save = open("./node_name.txt", "w")
for i, n in enumerate(converted_graph_def.node):
   a = ("Name of the node - {}}\n".format(n.name))
   # print("=============")
   print(a)
   f_save.writelines(a)

3 MetaGraph

MetaGraph

在GraphDef中无法得到Variable,但通过MetaGraph可以得到。
MetaGraph的官方解释:一个MetaGraph由一个计算图和其相关的元数据构成。其中包含了4中主要信息:

  1. MetaInfoDef 】存放了一些元信息,例如版本和其他用户信息
  2. GraphDef 】MetaGraph的核心内容之一
  3. SaverDef 】图的Saver信息(例:最多同时保存的checkpoint数量、需保存的tensor的名字(并不保存tensor中的实际内容))
  4. CollectionDef 】任何需要特殊注意的Python对象,需要特殊的标注以方便import_meta_graph后取回(例:“train_op”、"prediction"等)

.

集合collection

是为了方便用户对图中的变量进行管理而被创建的一个概念,通过一个string类型的key来对一组python对象进行命名的集合。这个key可以是Tensorflow在内部定义的一些key,也可以是用户自定义的名字(string)。

  • Tensorflow中内部定义的许多标准key,全部定义在了tf.GraphKeys这个类中(例:tf.graphKeys.TRAINBLE_VARIABLE、tf.GraphKey.GLOBAL_VARIABLES等)。
    tf.trainable.variables() <==> tf.get_collection(tf.GraphKey.TRAINABLE_VARIABLES)
    tf.global_variables() <==> tf.get_colleciontion(tf.GraphKey.GLOBAL_VARIABLES)
  • 对于用户定义的Key:
    pred = model_network(X)
    loss = tf.reduce_mean(..., pred, ...)
    train_op = tf.train.AdamOptimizer(learingrate).minimize(loss)
    
    对于这一段对训练过程定义的代码,用户希望特别关注pred、loss、train_op这几个操作,那么就可以使用如下代码,将这几个变量加入collection中。并命名"training_collection":
    tf.add_to_collection("training_collection", pred)
    tf.add_to_collection("training_collection", loss)
    tf.add_to_collection("training_collection", train_op)
    
    然后通过train_collect = tf.get_collection("training_collection")得到python list,其内容就是pred、loss、train_op的Tensor。 这样操作一般是为了在一个新的session中打开这张图时,方便快速获取想要的张量。
    例子:
    with tf.session() as sess:
        pred = model_network(X)
        loss = tf.reduce_mean(..., pred, ...)
        train_op = tf.train.AdamOptimizer(learingrate).minimize(loss)
        tf.add_to_collection("training_collection", train_op)
        tf.train.export_meta_graph(tf.get_default_graph(),  "my_graph.meta")
    
    #通过import_meta_graph将图恢复(同时初始化本session的default图),
    #并通过get_collection重新获取train_op,以及通过train_op来开始一段训练(sess.run())
    with tf.Session() as sess1:
        tf.train.import_meta_graph("my_graph.meta")
        train_op = tf.get_collection("training_collection")[0]
        sess1.sun(train_op)
    

从MetaGraph中恢复构建的Graph是可以被训练的。
但。MetaGraph中虽然包含Variable的信息,但是没有Variable的实际值。所以从MetaGraph中恢复的Graph,训练都是从随机初始化的值开始的,训练中的Variable的实际值都保存在checkpoint文件中,如果要从之前训练装填恢复训练,就需要从checkpoint中restore。
.
tf.export_meta_graph() / tf.import_meta_graph():用来对MetaGraph读写的API。
tf.train.saver.save(): 在保存checkpoint的同时也会保存MetaGraph。
tf.train.saver.restore(): 在恢复图时,只恢复了Variable。所以需要加上tf.import_meta_graph()来从MetaGraph中恢复Graph。一般的,训练时不需要熊MetaGraph中恢复图Graph,而是在python中构建的网络的Graph,并对其恢复Variable。

4 CheckPoint

checkpoint中全面保存了训练某时间截断的信息,包括参数、超参数、梯度等。
tf.train.saver.save() / tf.train.saver.restore():则能够完整的保存和恢复神经网络的训练。
checkpoint分为两个文件保存Variable的二进制信息:ckpt文件保存了Variable的二进制信息,index文件用于保存了ckpt文件中对应Variable的偏移量信息。

总结

tensorflow三种API所保存和恢复的Graph是不一样的

  • GraphDef 的保存与加载的api: tf.train.write_graph() / tf.Import_graph_def()
  • MetaGraph读写的api :tf.export_meta_graph() / tf.import_meta_graph()
  • checkpoint读写的api :tf.train.saver.save() / tf.train.saver.restore()

在Python中构建的Graph,使用Tensorflow运行时,是将该图序列化到 Protocol Buffer得到GraphDef,以方便在后端运行。在此过程中,图的保存、恢复、运行都通过ProtoBuf来实现的。 GraphDef、MetaGraph、Variable、Collection、Saver等都有对应的ProtoBuf定义。例:用户只能找到Node的前一个Node,却无法知道自己的输出会被那个Node接受。

#本博客第一个代码运行结果
node {
  name: "Bdata1"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "Bdata2"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "multiply"
  op: "Mul"
  input: "Bdata1"
  input: "Bdata2"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 38
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值