学习TensorFlow
之后,我们知道,在TensorFlow
中,主要是通过构建graph
的形式来构建整个框架的。
本文想要彻底高清楚的问题就是:TensorFlow
中真个graph
的构建流程:
- 用户创建的时候是怎么构建的,也就是 python api 提供了什么样的形式来构建用户看的见的 graph;
- python api 创建的 graph 如何转换成 c++ 底层可用于计算的graph;
- graph 最终执行在硬件设备上是以什么形式完成计算的;
我想,搞清楚这三个问题,基本上就掌握了深度学习框架,是怎么完成真个计算的。
图(graph)
是 TensorFlow
用于表达计算任务的一个核心概念。从前端(python)
描述神经网络的结构,到后端在多机和分布式系统上部署,到底层 Device(CPU、GPU、TPU)
上运行,都是基于图来完成。
在tensorflow官方文档中,Graph被定义为“一些Operation和Tensor的集合”。我们使用的python代码表达的计算,就会生成一张图。
第一个阶段:Python Graph
import tensorflow as tf
bg = tf.Graph() # 创建一个图bg
with bg.as_default(): # 将bg 图设为当默认图
Bdata1 = tf.placeholder(tf.float32, name="Bdata1")
Bdata2 = tf.placeholder(tf.float32, name="Bdata2")
Bdata3 = tf.multiply(Bdata1, Bdata2, name="multiply")
tf.summary.FileWriter("./test", Bgraph)
上述代码,可视化之后的图如下
上图中,每一个圈表示一个算子(OP),椭圆到椭圆的边为张量(tensor),箭头的指向表示了这张图的Op的输入输出的tensor
的传递关系。
Python
代码中所描述的 Graph
(简称 Python Graph
),包含各个OP
和 tensor
。构建好后,tensorflow
运行时,将session
启动,真实的数据计算会被放在多CPU、GPU、ARM
等完成,并不是始终不变的东西。单纯的使用Python
的是无法有效的完成计算的。所以计算过程为:tensorflow
先将Python
代码描绘的图进行转换,转换为Protocol Buffer
(即序列化),在通过C/C++/CUDA
运行Protocol Buffer
所定义的图。
第二个阶段:Protocol Buffer
在第一阶段,用户使用Python
构建了自己的模型的计算图,到了第二个阶段,要执行运算的时候,Python ``Graph
会被序列化为Protocol Buffer
,称之为Graph Def
。
GraphDef
由许多的叫做 NodeDef
的Protocol Buffer
组成,NodeDef
与Python Graph
中的 Operation
相对(也就是 tf.Operation
的序列化 ProtoBuf
是NodeDef
)。
GraphDef
的保存与加载的api
为 tf.train.write_graph()/ tf.Import_graph_def()
import tensorflow as tf
bg = tf.Graph() # 创建一个图bg
with Bgraph.as_default(): # 将bg图设为当默认图
Bdata1 = tf.placeholder(tf.float32, name="Bdata1")
Bdata2 = tf.placeholder(tf.float32, name="Bdata2")
Bdata3 = tf.multiply(Bdata1, Bdata2, name="multiply")
with tf.Session(graph=bg) as sess:
tf.train.write_graph(sess.graph_def, "./", "test.pb", False)
print(sess.graph_def)
代码中的GraphDef
实际包含了:3个NodeDef
。NodeDef
中,包含name, op, input, attr
。其中input
是不同的node
之间的连接信息。
GraphDef
中不会保存Variable
的信息,但可以保存Constant
。所以是将Python Graph
中的Variable
替换为constant
存储在GraphDef
中(训练好的weight(variable)
得以保存)。所以从graphdef
来恢复的图和权重,没有Variable
,只能用于实际上的inference
,无法用于训练。
上面代码的运行结果
#上面代码运行结果
node {
name: "Bdata1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
}
node {
name: "Bdata2"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
}
node {
name: "multiply"
op: "Mul"
input: "Bdata1"
input: "Bdata2"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
versions {
producer: 38
}
上述代码,使用api tf.train.write_graph()
将python graph
转换成了graphdef
。我么从结果来看,是没有实际的参数的。那么怎样才能将参数 Variables
拿到呢?
第三阶段:Meta Graph
在GraphDef
中无法得到Variables
,但通过MetaGraph
可以得到。
MetaGraph
的官方解释:一个MetaGraph
由一个计算图和其相关的元数据构成。其中包含了4中主要信息:
MetaInfoDef
:存放了一些元信息,例如版本和其他用户信息GraphDef
:MetaGraph
的核心内容之一SaverDef
:图的Saver
信息(例:最多同时保存的checkpoint
数量、需保存的tensor
的名字(并不保存tensor
中的实际内容))CollectionDef
:任何需要特殊注意的Python
对象,需要特殊的标注以方便import_meta_graph
后取回(例:"train_op"、"prediction"
等)
集合collection
是为了方便用户对图中的变量进行管理而被创建的一个概念,通过一个string
类型的key
来对一组python
对象进行命名的集合。这个key
可以是Tensorflow
在内部定义的一些key
,也可以是用户自定义的名字(string)
。
Tensorflow
中内部定义的许多标准key
,全部定义在了tf.GraphKeys
这个类中(例:tf.graphKeys.TRAINBLE_VARIABLE
、tf.GraphKey.GLOBAL_VARIABLES
等)。
以下写法存在等价关系:
tf.trainable.variables() <==> tf.get_collection(tf.GraphKey.TRAINABLE_VARIABLES)
tf.global_variables() <==> tf.get_colleciontion(tf.GraphKey.GLOBAL_VARIABLES)
用户可以自定义 key 集合:
pred = model_network(X)
loss = tf.reduce_mean(..., pred, ...)
train_op = tf.train.AdamOptimizer(learingrate).minimize(loss)
'''
对于这一段对训练过程定义的代码,用户希望特别关注pred、loss、train_op这几个操作,
那么就可以使用如下代码,将这几个变量加入collection中。并命名"training_collection":
'''
tf.add_to_collection("training_collection", pred)
tf.add_to_collection("training_collection", loss)
tf.add_to_collection("training_collection", train_op)
然后通过train_collect = tf.get_collection("training_collection")
得到python list
,其内容就是pred、loss、train_op
的Tensor
。 这样操作一般是为了在一个新的session
中打开这张图时,方便快速获取想要的张量。
with tf.session() as sess:
pred = model_network(X)
loss = tf.reduce_mean(..., pred, ...)
train_op = tf.train.AdamOptimizer(learingrate).minimize(loss)
tf.add_to_collection("training_collection", train_op)
tf.train.export_meta_graph(tf.get_default_graph(), "my_graph.meta")
#通过import_meta_graph将图恢复(同时初始化本session的default图),
#并通过get_collection重新获取train_op,以及通过train_op来开始一段训练(sess.run())
with tf.Session() as sess1:
tf.train.import_meta_graph("my_graph.meta")
train_op = tf.get_collection("training_collection")[0]
sess1.sun(train_op)
注意
从MetaGraph
中恢复构建的Graph
是可以被训练的。但,MetaGraph
中虽然包含Variable
的信息,但是没有Variable
的实际值。所以从MetaGraph
中恢复的Graph
,训练都是从随机初始化的值开始的,训练中的Variable
的实际值都保存在checkpoint
文件中,如果要从之前训练装填恢复训练,就需要从checkpoint
中restore
。
tf.export_meta_graph() / tf.import_meta_graph()
用来对MetaGraph
读写的API
。
tf.train.saver.save()
在保存checkpoint
的同时也会保存MetaGraph
。
tf.train.saver.restore()
在恢复图时,只恢复了Variable
。所以需要加上tf.import_meta_graph()
来从MetaGraph
中恢复Graph
。一般的,训练时不需要从MetaGraph
中恢复图Graph
,而是在python
中构建的网络的Graph
,并对其恢复Variable
。
checkpoint中全面保存了训练某时间截断的信息,包括参数、超参数、梯度等。
tf.train.saver.save() / tf.train.saver.restore():则能够完整的保存和恢复神经网络的训练。
checkpoint分为两个文件保存Variable的二进制信息:ckpt文件保存了Variable的二进制信息,index文件用于保存了ckpt文件中对应Variable的偏移量信息。
总结
Tensorflow
三种API
所保存和恢复的Graph
是不一样的
GraphDef
的保存与加载的api: tf.train.write_graph() / tf.Import_graph_def()
MetaGraph
读写的api :tf.export_meta_graph() / tf.import_meta_graph()
checkpoint
读写的api :tf.train.saver.save() / tf.train.saver.restore()
总觉的还是没有分析的很清楚,等继续深入研究学习,之后争取将来龙去脉讲清楚。
声明
本博客是个人学习时的一些笔记摘录和感想,不保证是为原创,内容汇集了网上相关资料和书记内容,在这之中也必有疏漏未加标注者,如有侵权请与博主联系。