keras java 模型导出_tensorflow 模型导出总结

tensorflow 1.0 以及2.0 提供了多种不同的模型导出格式,例如说有checkpoint,SavedModel,Frozen GraphDef,Keras model(HDF5) 以及用于移动端,嵌入式的TFLite。 本文主要讲解了前4中导出格式,分别介绍了四种的导出的各种方式,以及加载,涉及了python以及java的实现。TFLite由于项目中没有涉及,之后会补充。

模型导出主要包含了:参数以及网络结构的导出,不同的导出格式可能是分别导出,或者是整合成一个独立的文件。参数和网络结构分开保存:checkpoint, SavedModel

只保存权重:HDF5(可选)

参数和网络结构保存在一个文件:Frozen GraphDef,HDF5(可选)

在tensorflow 1.0中,可以见下图,主要有三种主要的API,Keras,Estimator,以及Legacy即最初的session模型,其中tf.Keras主要保存为HDF5,Estimator保存为SavedModel,而Lagacy主要保存的是Checkpoint,并且可以通过freeze_graph,将模型变量冻结,得到Frozen GradhDef的文件。这三种格式的模型,都可以通过TFLite Converter导出为 .tflite 的模型文件,用于安卓/ios/嵌入式设备的serving。

在tensorflow 2.0中,推荐使用SavedModel进行模型的保存,所以keras默认导出格式是SavedModel,也可以通过显性使用 .h5 后缀,使得保存的模型格式为HDF5 。 此外其他low level API,都支持导出为SavedModel格式,以及Concrete Functions。Concrete Function是一个签名函数,有固定格式的输入和输出。 最终转化成Flatbuffer,服务端运行结束。

checkpint 的导出是网络结构和参数权重分开保存的。

其组成:

checkpoint # 列出该目录下,保存的所有的checkpoint列表,下面有具体的例子

events.out.tfevents.1583930869.prod-cloudserver-gpu169 # tensorboad可视化所需文件,可以直观看出模型的结构

'''model.ckpt-13000表示前缀,代表第13000 global steps时的保存结果,我们在指定checkpoint加载时,也只需要说明前缀即可。'''

model.ckpt-13000.index # 代表了参数名

model.ckpt-13000.data-00000-of-00001 # 代表了参数值

model.ckpt-13000.meta # 代表了网络结构

所以一个checkpoint 组成是由两个部分,三个文件组成,其中网络结构部分(meta文件),以及参数部分(参数名:index,参数值:data)

其中checkpoint文件中

model_checkpoint_path: "model.ckpt-16329"

all_model_checkpoint_paths: "model.ckpt-13000"

all_model_checkpoint_paths: "model.ckpt-14000"

all_model_checkpoint_paths: "model.ckpt-15000"

all_model_checkpoint_paths: "model.ckpt-16000"

all_model_checkpoint_paths: "model.ckpt-16329"

使用tensorboard --logdir PATH_TO_CHECKPOINT: tensorboard 会调用events.out.tfevents.*

文件,并生成tensorboard,例如下图

导出成CKPTtensorflow 1.0

# in tensorflow 1.0

saver = tf.train.Saver()

saver.save(sess=session, save_path=args.save_path)estimator

# estimator

"""通过 RunConfig 配置多少时间或者多少个steps 保存一次模型,默认600s 保存一次。具体参考 https://zhuanlan.zhihu.com/p/112062303"""

run_config = tf.estimator.RunConfig(

model_dir=FLAGS.output_dir, # 模型保存路径

session_config=config,

save_checkpoints_steps=FLAGS.save_checkpoints_steps, # 多少steps保存一次ckpt

keep_checkpoint_max=1)

estimator = tf.estimator.Estimator(

model_fn=model_fn,

config=run_config,

params=None

)关于estimator的介绍可以参考https://zhuanlan.zhihu.com/p/112062303​zhuanlan.zhihu.come9e9f03f706608d02c8160ce7209b0a6.png

加载CKPTtf1.0

ckpt加载的脚本如下,加载完后,session就会是保存的ckpt了。

# tf1.0

session = tf.Session()

session.run(tf.global_variables_initializer())

saver = tf.train.Saver()

saver.restore(sess=session, save_path=args.save_path) # 读取保存的模型对于estimator 会自动load output_dir 中的最新的ckpt。

我们常用的model_file = tf.train.latest_checkpoint(FLAGS.output_dir) 获取最新的ckpt

SavedModel

SavedModel 格式是tensorflow 2.0 推荐的格式,他很好地支持了tf-serving等部署,并且可以简单被python,java等调用。

一个 SavedModel 包含了一个完整的 TensorFlow program, 包含了 weights 以及 计算图 computation. 它不需要原本的模型代码就可以加载所以很容易在 TFLite, TensorFlow.js, TensorFlow Serving, or TensorFlow Hub 上部署。

通常SavedModel由以下几个部分组成

├── assets/ # 所需的外部文件,例如说初始化的词汇表文件,一般无

├── assets.extra/ # TensorFlow graph 不需要的文件, 例如说给用户知晓的如何使用SavedModel的信息. Tensorflow 不使用这个目录下的文件。

├── saved_model.pb # 保存的是MetaGraph的网络结构

├── variables # 参数权重,包含了所有模型的变量(tf.Variable objects)参数

├── variables.data-00000-of-00001

└── variables.index

导出为SavedModeltf 1.0 方式

"""tf1.0"""

x = tf.placeholder(tf.float32, [None, 784], name="myInput")

y = tf.nn.softmax(tf.matmul(x, W) + b, name="myOutput")

tf.saved_model.simple_save(

sess,

export_dir,

inputs={"myInput": x},

outputs={"myOutput": y})

simple_save 是对于普通的tf 模型导出的最简单的方式,只需要补充简单的必要参数,有很多参数被省略,其中最重要的参数是tag:

tag 是用来区别不同的 MetaGraphDef,这是在加载模型所需要的参数。其默认值是tag_constants.SERVING (“serve”).

对于某些节点,如果没有办法直接加name,那么可以采用 tf.identity, 为节点加名字,例如说CRF的输出,以及使用dataset后,无法直接加input的name,都可以采用这个方式:

def addNameToTensor(someTensor, theName):

return tf.identity(someTensor, name=theName)estimator 方式

"""estimator"""

def serving_input_fn():

label_ids = tf.placeholder(tf.int32, [None], name='label_ids')

input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_id

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值