TensorFlow模型的保存与加载(二)——pb模式【源码】

如果本文对您有帮助,欢迎点赞支持!

目录

前言

1、TF模型保存方法

2、pb模式

3、适合保存模型的时机

一、保存模型

1、定义简单网络模型

2、保存网络模型为pb文件

二、加载网络模型


前言

1、TF模型保存方法

网络模型的保存和重载操作是学习和训练AI模型的必备技能之一,也是进一步学习迁移学习知识的基础。Tensorflow的模型保存加载有不同格式,使用方法也不一样。目前来看,Tensorflow的保存方式按照生成的主要文件的格式基本可以分为三种:(1)checkpoint模式;(2) pb模式;(3)saved_model模式。

2、pb模式

通常我们使用 TensorFlow时保存模型都使用 ckpt 格式的模型文件,但是这种方式依赖 TensorFlow 的,只能在其框架下使用。

谷歌推荐保存模型为 PB 文件,这种方式具有以下优点:

(1)PB文件是一种可以跨语言,跨运行环境的序列化格式,几乎任何主流编程语言都可以解析它,它也允许其它深度学习框架读取、继续训练和迁移 TensorFlow 的模型。

(2)PB 文件会将模型的变量都会变成常量来使模型的大小大大减小,适合在手机端运行。

3、适合保存模型的时机

尽可能多地保存模型能帮助我们不错过效果最好的模型,但是实际操作也要考虑内存大小和运行效率。

工程师的通用做法是每训练多少步后就在验证集上计算一次准确率,如果本次结果比上次好则保存新的模型。最常见的做法是直接每训练多少步就直接保存一次模型,实际上TensorFlow的API也提供了在模型文件名提供添加训练step后缀的方式。

一、保存模型

在上述内容我们只是介绍了方法步骤,而且路径的输入上也存在一些技巧,所以我们实际使用时要将上述步骤封装起来,下面我们将上述方法封装起来并给出测试代码和运行效果:

1、定义简单网络模型

为了让读者将精力集中到模型的保存和加载上,这里我们使用一个非常简单的计算图结构x*y+b:

# 定义计算图结构 x*y+b
x = tf.placeholder(tf.int32, name='x_input')
y = tf.placeholder(tf.int32, name='y_input')
b = tf.Variable(1, name='b')
xy = tf.multiply(x, y)
output = tf.add(xy, b, name='output')

运行上述计算图,相关测试代码如下:

# 输出计算图结果
sess = tf.Session()
sess.run(tf.global_variables_initializer())
y_pred = sess.run(output, {x: 10, y: 3})
print(y_pred) # 输出31

最后输出结果是10*3+1=31,和我们预想的一致。

2、保存网络模型为pb文件

接下来我们使用如下方法保存该计算图模型:

save_pb_model(sess,'./models/',['output'])

save_pb_model()是我们封装的方法,其相关代码如下:

def save_pb_model(sess, save_path, output_nodes):
    # 将变量转换为常量,第1个参数是会话,第2个参数计算图的graph_def对象;第3个参数是结果输出节点的名称
    output_graph_def = graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), output_nodes)
    # 检查是否存在路径
    path = os.path.abspath(save_path)  # 获取绝对路径
    if os.path.exists(path) is False:
        os.makedirs(path)
        print("成功创建模型保存新路径:{}".format(path))
    # 将计算图写入序列化的pb文件,第1个参数是保存路径及其文件名
    with tf.gfile.FastGFile(save_path + "model.pb", mode="wb") as f:
        f.write(output_graph_def.SerializeToString())
        print("成功使用PB模式保存模型到路径:{}".format(path))

成功运行后,会生成如下文件:

注意,运行代码中,代码编辑器可能不会马上更新文件目录,建议打开目录查看下是否生成。

二、加载网络模型

现在我们要将上述保存的计算图结构x*y+b从文件中重载出来,下面直接给出加载代码:

# 加载计算图结构 x*y+b
sess = tf.Session()
load_pb_model(sess, './models/')
# 获取计算图节点
graph = tf.get_default_graph()  # 获取计算图
x = graph.get_tensor_by_name("x_input:0")
y = graph.get_tensor_by_name("y_input:0")
output= graph.get_tensor_by_name("output:0")
# 运行计算图计算结果
y_pred = sess.run(output, {x: 10, y: 3})
print(y_pred) # 输出 31

其中加载方法load_pb_model()代码如下:

def load_pb_model(sess, save_path):
    with tf.gfile.FastGFile(save_path + 'model.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')  # 导入计算图

 

  • 16
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

魔法攻城狮MRL

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值