java导出mpp格式_tensorflow 模型导出总结

ae276255051598a23b9034c81c6cb794.png

25901b2e4176ba624d0b449240ae85ad.png
  • Checkpoints
    • 导出成CKPT
    • 加载CKPT
  • SavedModel
    • 导出为SavedModel
    • 加载SavedModel
      • Python 加载
      • JAVA 加载
      • CLI 加载
  • Frozen Graph
    • 导出为pb
      • python
      • CLI转换工具
    • 模型加载
      • Python 加载
      • Java 加载
  • HDF5
    • HDF5导出
    • HDF5加载
  • tfLite
    • TFlite转换
    • TFLite 加载
  • ref

tensorflow 1.0 以及2.0 提供了多种不同的模型导出格式,例如说有checkpoint,SavedModel,Frozen GraphDef,Keras model(HDF5) 以及用于移动端,嵌入式的TFLite。 本文主要讲解了前4中导出格式,分别介绍了四种的导出的各种方式,以及加载,涉及了python以及java的实现。TFLite由于项目中没有涉及,之后会补充。

模型导出主要包含了:参数以及网络结构的导出,不同的导出格式可能是分别导出,或者是整合成一个独立的文件。

  • 参数和网络结构分开保存:checkpoint, SavedModel
  • 只保存权重:HDF5(可选)
  • 参数和网络结构保存在一个文件:Frozen GraphDef,HDF5(可选)

在tensorflow 1.0中,可以见下图,主要有三种主要的API,Keras,Estimator,以及Legacy即最初的session模型,其中tf.Keras主要保存为HDF5,Estimator保存为SavedModel,而Lagacy主要保存的是Checkpoint,并且可以通过freeze_graph,将模型变量冻结,得到Frozen GradhDef的文件。这三种格式的模型,都可以通过TFLite Converter导出为 .tflite 的模型文件,用于安卓/ios/嵌入式设备的serving。

f0c7981fe93ca8bc2f3085e559ce2abe.png

在tensorflow 2.0中,推荐使用SavedModel进行模型的保存,所以keras默认导出格式是SavedModel,也可以通过显性使用 .h5 后缀,使得保存的模型格式为HDF5 。 此外其他low level API,都支持导出为SavedModel格式,以及Concrete Functions。Concrete Function是一个签名函数,有固定格式的输入和输出。 最终转化成Flatbuffer,服务端运行结束。

checkpint 的导出是网络结构和参数权重分开保存的。
其组成:

checkpoint # 列出该目录下,保存的所有的checkpoint列表,下面有具体的例子
events.out.tfevents.1583930869.prod-cloudserver-gpu169 # tensorboad可视化所需文件,可以直观看出模型的结构
'''
model.ckpt-13000表示前缀,代表第13000 global steps时的保存结果,我们在指定checkpoint加载时,也只需要说明前缀即可。
'''
model.ckpt-13000.index # 代表了参数名
model.ckpt-13000.data-00000-of-00001 # 代表了参数值
model.ckpt-13000.meta # 代表了网络结构

所以一个checkpoint 组成是由两个部分,三个文件组成,其中网络结构部分(meta文件),以及参数部分(参数名:index,参数值:data)

其中checkpoint文件中

model_checkpoint_path: "model.ckpt-16329"
all_model_checkpoint_paths: "model.ckpt-13000"
all_model_checkpoint_paths: "model.ckpt-14000"
all_model_checkpoint_paths: "model.ckpt-15000"
all_model_checkpoint_paths: "model.ckpt-16000"
all_model_checkpoint_paths: "model.ckpt-16329"

使用tensorboard --logdir PATH_TO_CHECKPOINT: tensorboard 会调用events.out.tfevents.*
文件,并生成tensorboard,例如下图

c67cece5acdb417fa66405a58751b93f.png

导出成CKPT

  • tensorflow 1.0
# in tensorflow 1.0
saver = tf.train.Saver()
saver.save(sess=session, save_path=args.save_path)
  • estimator
# estimator
"""
通过 RunConfig 配置多少时间或者多少个steps 保存一次模型,默认600s 保存一次。
具体参考 https://zhuanlan.zhihu.com/p/112062303
"""
run_config = tf.estimator.RunConfig(
    model_dir=FLAGS.output_dir, # 模型保存路径
    session_config=config,
    save_checkpoints_steps=FLAGS.save_checkpoints_steps, # 多少steps保存一次ckpt
    keep_checkpoint_max=1)
estimator = tf.estimator.Estimator(
  model_fn=model_fn,
  config=run_config,
  params=None
)
关于estimator的介绍可以参考
https://zhuanlan.zhihu.com/p/112062303​zhuanlan.zhihu.com
168c7a95f92477943afa56263537b270.png

加载CKPT

  • tf1.0
    ckpt加载的脚本如下,加载完后,session就会是保存的ckpt了。
# tf1.0
session = tf.Session()
session.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess=session, save_path=args.save_path)  # 读取保存的模型
  • 对于estimator 会自动load output_dir 中的最新的ckpt。
  • 我们常用的model_file = tf.train.latest_checkpoint(FLAGS.output_dir) 获取最新的ckpt

SavedModel

SavedModel 格式是tensorflow 2.0 推荐的格式,他很好地支持了tf-serving等部署,并且可以简单被python,java等调用。

一个 SavedModel 包含了一个完整的 TensorFlow program, 包含了 weights 以及 计算图 computation. 它不需要原本的模型代码就可以加载所以很容易在 TFLite, TensorFlow.js, TensorFlow Serving, or TensorFlow Hub 上部署。

通常SavedModel由以下几个部分组成

├── assets/ # 所需的外部文件,例如说初始化的词汇表文件,一般无
├── assets.extra/ # TensorFlow graph 不需要的文件, 例如说给用户知晓的如何使用SavedModel的信息. Tensorflow 不使用这个目录下的文件。
├── saved_model.pb # 保存的是MetaGraph的网络结构
├── variables # 参数权重,包含了所有模型的变量(tf.Variable objects)参数
    ├── variables.data-00000-of-00001
    └── variables.index

导出为SavedModel

  • tf 1.0 方式
"""tf1.0"""
x = tf.placeholder(tf.float32, [None, 784], name="myInput")
y = tf.nn.softmax(tf.matmul(x, W) + b, name="myOutput")
tf.saved_model.simple_save(
                sess,
                export_dir,
                inputs={
    "myInput": x},
                outputs={
    "myOutput": y})

simple_save 是对于普通的tf 模型导出的最简单的方式,只需要补充简单的必要参数,有很多参数被省略,其中最重要的参数是tagtag 是用来区别不同的

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值