tensorflow || 1. Protocol Buffer
tensorflow || 2. tensorflow框架实现的Graph总结
tensorflow || 3. graph的相关操作、保存与加载pb文件
1. Graph
在tensorflow官方文档中,Graph被定义为“一些Operation和Tensor的集合”。我们使用的python代码表达的计算,就会生成一张图。并在tensorboard中可视化
import tensorflow as tf Bgraph = tf.Graph() # 创建一个图Bgraph with Bgraph.as_default(): # 将Bgraph图设为当默认图 Bdata1 = tf.placeholder(tf.float32, name="Bdata1") Bdata2 = tf.placeholder(tf.float32, name="Bdata2") Bdata3 = tf.multiply(Bdata1, Bdata2, name="multiply") tf.summary.FileWriter("./test", Bgraph)
上图中,每一个圆圈表示一个Op,椭圆到椭圆的边为tensor,箭头的指向表示了这张图的Op的输入输出的tensor的传递关系。
.
- 【构建Python Graph】Python代码中所描述的 Graph(简称Python Graph),包含各个op和tensor。
- 【tensorflow运行时图的序列化】tensorflow先将Python Graph进行转换为Protocol Buffer(即序列化,序列化后的图称为GraphDef)。
- 【session启动】通过C/C++/CUDA运行GraphDef,真实的数据计算会被放在多CPU、GPU、ARM等完成,并不是始终不变的东西。
至此,引出了GraphDef的存在。
2 GraphDef
将Python Graph进行序列化,得到的图称作GraphDef。
GraphDef由许多的叫做NodeDef的Protocol Buffer组成,NodeDef与Python Graph中的Operation相对应(也就是 tf.Operation的序列化ProtoBuf 是 NodeDef)。
GraphDef 的保存与加载的api为tf.train.write_graph()
/tf.Import_graph_def()
import tensorflow as tf Bgraph = tf.Graph() # 创建一个图Bgraph with Bgraph.as_default(): # 将Bgraph图设为当默认图 Bdata1 = tf.placeholder(tf.float32, name="Bdata1") Bdata2 = tf.placeholder(tf.float32, name="Bdata2") Bdata3 = tf.multiply(Bdata1, Bdata2, name="multiply") with tf.Session(graph=Bgraph) as sess: tf.train.write_graph(sess.graph_def, "./", "test.pb", False) print(sess.graph_def)
代码中的GraphDef中的内容:3个NodeDef。NodeDef中,包含name, op, input, attr。其中input是不同的node之间的连接信息。
GraphDef中不会保存Variable的信息,但可以保存Constant。
所以是将Python Graph 中的Variable替换为constant 存储在GraphDef中(训练好的weight(variable)得以保存)。所以从graphdef来恢复的图和权重,没有Variable,只能用于实际上的inference,无法用于训练。
另外,我们有时候会想以网络的中间节点 作为输出节点来保存pb模型。
其中核心api为tf.graph_util.convert_variables_to_constants()
、converted_graph_def.SerializeToString()
import tensorflow as tf input_a = tf.placeholder(tf.float32, (1, 224, 224,3), "l1") input_b = tf.placeholder(tf.float32, (1, 224, 224,3), "l2") features = tf.subtract(input_a,input_b) print(features) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) output_node_names = ["Sub"] # 输出节点 converted_graph_def = tf.graph_util.convert_variables_to_constants(sess, input_graph_def = sess.graph.as_graph_def(), output_node_names = output_node_names) with tf.gfile.GFile("./pb_file_test.pb", "wb") as f: f.write(converted_graph_def.SerializeToString()) f_save = open("./node_name.txt", "w") for i, n in enumerate(converted_graph_def.node): a = ("Name of the node - {}}\n".format(n.name)) # print("=============") print(a) f_save.writelines(a)
3 MetaGraph
MetaGraph
在GraphDef中无法得到Variable,但通过MetaGraph可以得到。
MetaGraph的官方解释:一个MetaGraph由一个计算图和其相关的元数据构成。其中包含了4中主要信息:
- 【 MetaInfoDef 】存放了一些元信息,例如版本和其他用户信息
- 【 GraphDef 】MetaGraph的核心内容之一
- 【 SaverDef 】图的Saver信息(例:最多同时保存的checkpoint数量、需保存的tensor的名字(并不保存tensor中的实际内容))
- 【 CollectionDef 】任何需要特殊注意的Python对象,需要特殊的标注以方便import_meta_graph后取回(例:“train_op”、"prediction"等)
.
集合collection
是为了方便用户对图中的变量进行管理而被创建的一个概念,通过一个string类型的key来对一组python对象进行命名的集合。这个key可以是Tensorflow在内部定义的一些key,也可以是用户自定义的名字(string)。
- Tensorflow中内部定义的许多标准key,全部定义在了tf.GraphKeys这个类中(例:tf.graphKeys.TRAINBLE_VARIABLE、tf.GraphKey.GLOBAL_VARIABLES等)。
tf.trainable.variables() <==> tf.get_collection(tf.GraphKey.TRAINABLE_VARIABLES)
tf.global_variables() <==> tf.get_colleciontion(tf.GraphKey.GLOBAL_VARIABLES)
- 对于用户定义的Key:
对于这一段对训练过程定义的代码,用户希望特别关注pred、loss、train_op这几个操作,那么就可以使用如下代码,将这几个变量加入collection中。并命名"training_collection":pred = model_network(X) loss = tf.reduce_mean(..., pred, ...) train_op = tf.train.AdamOptimizer(learingrate).minimize(loss)
然后通过tf.add_to_collection("training_collection", pred) tf.add_to_collection("training_collection", loss) tf.add_to_collection("training_collection", train_op)
train_collect = tf.get_collection("training_collection")
得到python list,其内容就是pred、loss、train_op的Tensor。 这样操作一般是为了在一个新的session中打开这张图时,方便快速获取想要的张量。
例子:with tf.session() as sess: pred = model_network(X) loss = tf.reduce_mean(..., pred, ...) train_op = tf.train.AdamOptimizer(learingrate).minimize(loss) tf.add_to_collection("training_collection", train_op) tf.train.export_meta_graph(tf.get_default_graph(), "my_graph.meta")
#通过import_meta_graph将图恢复(同时初始化本session的default图), #并通过get_collection重新获取train_op,以及通过train_op来开始一段训练(sess.run()) with tf.Session() as sess1: tf.train.import_meta_graph("my_graph.meta") train_op = tf.get_collection("training_collection")[0] sess1.sun(train_op)
从MetaGraph中恢复构建的Graph是可以被训练的。
但。MetaGraph中虽然包含Variable的信息,但是没有Variable的实际值。所以从MetaGraph中恢复的Graph,训练都是从随机初始化的值开始的,训练中的Variable的实际值都保存在checkpoint文件中,如果要从之前训练装填恢复训练,就需要从checkpoint中restore。
.
tf.export_meta_graph() / tf.import_meta_graph()
:用来对MetaGraph读写的API。
tf.train.saver.save()
: 在保存checkpoint的同时也会保存MetaGraph。
tf.train.saver.restore()
: 在恢复图时,只恢复了Variable。所以需要加上tf.import_meta_graph()
来从MetaGraph中恢复Graph。一般的,训练时不需要熊MetaGraph中恢复图Graph,而是在python中构建的网络的Graph,并对其恢复Variable。
4 CheckPoint
checkpoint中全面保存了训练某时间截断的信息,包括参数、超参数、梯度等。
tf.train.saver.save() / tf.train.saver.restore()
:则能够完整的保存和恢复神经网络的训练。
checkpoint分为两个文件保存Variable的二进制信息:ckpt文件保存了Variable的二进制信息,index文件用于保存了ckpt文件中对应Variable的偏移量信息。
总结
tensorflow三种API所保存和恢复的Graph是不一样的
- GraphDef 的保存与加载的api:
tf.train.write_graph() / tf.Import_graph_def()
- MetaGraph读写的api :
tf.export_meta_graph() / tf.import_meta_graph()
- checkpoint读写的api :
tf.train.saver.save() / tf.train.saver.restore()
在Python中构建的Graph,使用Tensorflow运行时,是将该图序列化到 Protocol Buffer得到GraphDef,以方便在后端运行。在此过程中,图的保存、恢复、运行都通过ProtoBuf来实现的。 GraphDef、MetaGraph、Variable、Collection、Saver等都有对应的ProtoBuf定义。例:用户只能找到Node的前一个Node,却无法知道自己的输出会被那个Node接受。
#本博客第一个代码运行结果
node {
name: "Bdata1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
}
node {
name: "Bdata2"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
}
node {
name: "multiply"
op: "Mul"
input: "Bdata1"
input: "Bdata2"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
versions {
producer: 38
}