Tensorflow模型保存与调用

1、tf2.1在keras中model的保存与调用

       tf.keras.model类中的save_weights方法和load_weights方法保存模型的权重。

       tf.keras.model.save方法可保存整个模型。保存的模型包括:

  • The model architecture, allowing to re-instantiate the model.(模型的结构)
  • The model weights.(模型的权重)
  • The state of the optimizer, allowing to resume training exactly where you left off.(优化器的选择)

       tf.keras.model.save方法允许在一个文件中将模型的全部状态记录下来。tf.keras.models.load_model可将保存下来的模型重建,并且支持直接使用。Models built with the Sequential and Functional API can be saved to both the HDF5 and SavedModel formats.

from keras.models import load_model

model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
del model  # deletes the existing model

# returns a compiled model
# identical to the previous one
model = load_model('my_model.h5')

HDF5文件数据格式参考文章:

1)http://docs.h5py.org/en/latest/index.html

2)https://blog.csdn.net/mzpmzk/article/details/89188968

 

2、什么是PB文件,保存为pb文件示例

      PB文件表示MetaGraph的protocal buffer格式的文件,MetaGraph包括计算图,数据流,以及相关的变量和输入输出signature以及asserts指创建计算图时额外的文件。

      谷歌推荐的保存模型的方式是保存模型为 PB 文件,它具有语言独立性,可独立运行,封闭的序列化格式,任何语言都可以解析它,它允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow 的模型。

示例代码如下

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import os
from tensorflow.python.framework import graph_util

pb_file_path = os.getcwd()

with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    # 这里的输出需要加上name属性
    op = tf.add(xy, b, name='op_to_store')

    sess.run(tf.global_variables_initializer())

    # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])

    # 测试 OP
    feed_dict = {x: 10, y: 3}
    print(sess.run(op, feed_dict))

    # 写入序列化的 PB 文件
    with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())

    # 输出
    # INFO:tensorflow:Froze 1 variables.
    # Converted 1 variables to const ops.
    # 31

       TensorFlow
       在TensorFlow中,模型的持久化保存和加载主要通过Saver()。在初次训练之后调用如下的save函数保存,然后,在预测前,或者在继续训练前调用load加载参数即可。

def __init__():
    self.sess = tf.Session()
    # 定义好网络结构...
    self.sess.run(tf.global_variables_initializer())
def check_path(self, path):
    if not os.path.exists(path):
        os.mkdir(path)
def save(self):
    self.check_path('model')
    saver=tf.train.Saver(tf.global_variables(),max_to_keep=10)
    print("model: ",saver.save(self.sess,'model/modle.ckpt'))

def load(self):
    saver=tf.train.Saver(tf.global_variables())
    module_file = tf.train.latest_checkpoint('model')
    saver.restore(self.sess, module_file)


 

 

 

参考文章:

【1】https://zhuanlan.zhihu.com/p/32887066

【2】https://blog.csdn.net/sunshinezhihuo/article/details/79705445

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值