1.背景
在上一篇如何使用Tensorflow保存或者加载模型(一)文章中,站长介绍了怎么把Tensorflow模型的图和变量通过tf.train.Saver()
保存在本地。在这一篇文章中,站长会介绍用一种新的模型保存和加载的方式,ModelBuilder API
,在该方式下保存和加载模型会更加简单,而且支持Python和Java环境下运行,可以更好地满足工业界的需求。
1.1 模型文件介绍
ModelBuilder API会生成saved_model.pb
的文件和variables
的文件夹。
saved_model.pb 中的后缀pb代表protobuf,在Tensorflow中这个pb文件包含了模型图的定义和模型的权重,也是模型保存的核心文件。
variables 文件夹中包含的是变量的数据和索引文件。
1.2 模型的保存示例代码
我们这里仍然使用linear regression
模型进行演示,使用tf.saved_model.simple_save
进行模型保存。
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
##1.创建PlaceHolder和初始化参数##
X = tf.placeholder("float", name="X")
Y = tf.placeholder("float", name="Y")
W = tf.Variable(np.random.randn(), name= "W")
b = tf.Variable(np.random.randn(), name= "b")
learning_rate = 0.02
epochs = 100
data_x = np.linspace(0, 50, 50)
data_y = np.linspace(0, 50, 50)
##2.实现梯度下降##
y_pred = tf.add(tf.multiply(X, W), b, name="y_pred")
loss = tf