最近接触了tensorflow的object detection API发现里面读取的预先训练模型都是pb格式。
谷歌推荐的保存模型的方式是保存模型为 PB 文件,它具有语言独立性,可独立运行,封闭的序列化格式,任何语言都可以解析它,它允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow 的模型。
它的主要使用场景是实现创建模型与使用模型的解耦, 使得前向推导 inference的代码统一。
另外的好处是保存为 PB 文件时候,模型的变量都会变成固定的,导致模型的大小会大大减小,适合在手机端运行。
还有一个就是,真正离线测试使用的时候,pb格式的数据能够保证数据不会更新变动,就是不会进行反馈调节啦。
保存 PB 文件的代码:
-
import
tensorflow
as
tf
-
import
os
-
from
tensorflow.python.framework
import
graph_util
-
-
pb_file_path
=
os
.
getcwd
()
-
-
with
tf
.
Session
(
graph
=
tf
.
Graph
())
as
sess
:
-
x
=
tf
.
placeholder
(
tf
.
int32
,
name
=
'x'
)
-
y
=
tf
.
placeholder
(
tf
.
int32
,
name
=
'y'
)
-
b
=
tf
.
Variable
(
1
,
name
=
'b'
)
-
xy
=
tf
.
multiply
(
x
,
y
)
-
# 这里的输出需要加上name属性
-
op
=
tf
.
add
(
xy
,
b
,
name
=
'op_to_store'
)
-
-
sess
.
run
(
tf
.
global_variables_initializer
())
-
-
# convert_variables_to_constants 需要指定output_node_names,list(),可以多个
-
constant_graph
=
graph_util
.
convert_variables_to_constants
(
sess
,
sess
.
graph_def
,
[
'op_to_store'
])
-
-
# 测试 OP
-
feed_dict
=
{
x
:
10
,
y
:
3
}
-
print
(
sess
.
run
(
op
,
feed_dict
))
-
-
# 写入序列化的 PB 文件
-
with
tf
.
gfile
.
FastGFile
(
pb_file_path
+
'model.pb'
,
mode
=
'wb'
)
as
f
:
-
f
.
write
(
constant_graph
.
SerializeToString
())
-
-
# 输出
-
# INFO:tensorflow:Froze 1 variables.
-
# Converted 1 variables to const ops.
-
# 31
加载 PB 模型文件典型代码:
-
from
tensorflow.python.platform
import
gfile
-
-
sess
=
tf
.
Session
()
-
with
gfile
.
FastGFile
(
pb_file_path
+
'model.pb'
,
'rb'
)
as
f
:
-
graph_def
=
tf
.
GraphDef
()
-
graph_def
.
ParseFromString
(
f
.
read
())
-
sess
.
graph
.
as_default
()
-
tf
.
import_graph_def
(
graph_def
,
name
=
''
)
# 导入计算图
-
-
# 需要有一个初始化的过程
-
sess
.
run
(
tf
.
global_variables_initializer
())
-
-
# 需要先复原变量
-
print
(
sess
.
run
(
'b:0'
))
-
# 1
-
-
# 输入
-
input_x
=
sess
.
graph
.
get_tensor_by_name
(
'x:0'
)
-
input_y
=
sess
.
graph
.
get_tensor_by_name
(
'y:0'
)
-
-
op
=
sess
.
graph
.
get_tensor_by_name
(
'op_to_store:0'
)
-
-
ret
=
sess
.
run
(
op
,
feed_dict
=
{
input_x
:
5
,
input_y
:
5
})
-
print
(
ret
)
-
# 输出 26
保存为 save model 格式也可以生成模型的 PB 文件,并且更加简单。
保存好以后到saved_model_dir目录下,会有一个saved_model.pb文件以及variables文件夹。顾名思义,variables保存所有变量,saved_model.pb用于保存模型结构等信息。
-
import
tensorflow
as
tf
-
import
os
-
from
tensorflow.python.framework
import
graph_util
-
-
pb_file_path
=
os
.
getcwd
()
-
-
with
tf
.
Session
(
graph
=
tf
.
Graph
())
as
sess
:
-
x
=
tf
.
placeholder
(
tf
.
int32
,
name
=
'x'
)
-
y
=
tf
.
placeholder
(
tf
.
int32
,
name
=
'y'
)
-
b
=
tf
.
Variable
(
1
,
name
=
'b'
)
-
xy
=
tf
.
multiply
(
x
,
y
)
-
# 这里的输出需要加上name属性
-
op
=
tf
.
add
(
xy
,
b
,
name
=
'op_to_store'
)
-
-
sess
.
run
(
tf
.
global_variables_initializer
())
-
-
# convert_variables_to_constants 需要指定output_node_names,list(),可以多个
-
constant_graph
=
graph_util
.
convert_variables_to_constants
(
sess
,
sess
.
graph_def
,
[
'op_to_store'
])
-
-
# 测试 OP
-
feed_dict
=
{
x
:
10
,
y
:
3
}
-
print
(
sess
.
run
(
op
,
feed_dict
))
-
-
# 写入序列化的 PB 文件
-
with
tf
.
gfile
.
FastGFile
(
pb_file_path
+
'model.pb'
,
mode
=
'wb'
)
as
f
:
-
f
.
write
(
constant_graph
.
SerializeToString
())
-
-
# INFO:tensorflow:Froze 1 variables.
-
# Converted 1 variables to const ops.
-
# 31
-
-
-
# 官网有误,写成了 saved_model_builder
-
builder
=
tf
.
saved_model
.
builder
.
SavedModelBuilder
(
pb_file_path
+
'savemodel'
)
-
# 构造模型保存的内容,指定要保存的 session,特定的 tag,
-
# 输入输出信息字典,额外的信息
-
builder
.
add_meta_graph_and_variables
(
sess
,
-
[
'cpu_server_1'
])
-
-
-
# 添加第二个 MetaGraphDef
-
#with tf.Session(graph=tf.Graph()) as sess:
-
# ...
-
# builder.add_meta_graph([tag_constants.SERVING])
-
#...
-
-
builder
.
save
()
# 保存 PB 模型
这种方法对应的导入模型的方法:
-
with
tf
.
Session
(
graph
=
tf
.
Graph
())
as
sess
:
-
tf
.
saved_model
.
loader
.
load
(
sess
,
[
'cpu_1'
],
pb_file_path
+
'savemodel'
)
-
sess
.
run
(
tf
.
global_variables_initializer
())
-
-
input_x
=
sess
.
graph
.
get_tensor_by_name
(
'x:0'
)
-
input_y
=
sess
.
graph
.
get_tensor_by_name
(
'y:0'
)
-
-
op
=
sess
.
graph
.
get_tensor_by_name
(
'op_to_store:0'
)
-
-
ret
=
sess
.
run
(
op
,
feed_dict
=
{
input_x
:
5
,
input_y
:
5
})
-
print
(
ret
)
-
# 只需要指定要恢复模型的 session,模型的 tag,模型的保存路径即可,使用起来更加简单
这样和之前的导入 PB 模型一样,也是要知道tensor的name。那么如何可以在不知道tensor name的情况下使用呢,实现彻底的解耦呢? 给add_meta_graph_and_variables
方法传入第三个参数,signature_def_map
即可。
参考:
https://zhuanlan.zhihu.com/p/32887066