tensorflow2.0的模型保存加载的几个方法

tensorflow2.0中模型的加载更加便捷。
我在github上新建了一个有关ner的项目,其中有对tensorflow2.0的api的一些详细使用。NER
想了解更多tensorflow2.0中模型存储加载方法,可以直接到其官方网站tf2.0.

我们这里说一下几个保存权重的方法:
假如当前建立的模型代码如下:

import tensorflow as tf
from tensorflow import keras
def get_model():
  # Create a simple model.
  inputs = keras.Input(shape=(32,))
  outputs = keras.layers.Dense(1)(inputs)
  model = keras.Model(inputs, outputs)
  model.compile(optimizer='adam', loss='mean_squared_error')
  return model

model = get_model()
opt = tf.keras.optimizers.Adam(0.1)
checkpoint_dir="./checkpoint"

1、保存检查点
具体api如下:

tf.train.Checkpoint
tf.train.CheckpointManager

使用以上两个api就可以保存训练中所有的权重。具体操作如下:
首先创建检查点

ckpt = tf.train.Checkpoint(optimizer=opt,model=model)
manager = tf.train.CheckpointManager(ckpt,
                                   checkpoint_dir,
                                   max_to_keep=3)

具体参数含义可以直接help查看api中参数的解释。
创建完检查点后,如果存在旧模型,就需要从旧模型中恢复权重。操作如下:

ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
     print("Restored from {}".format(manager.latest_checkpoint))
else:
     print("Initializing from scratch.")

然后再看看如何保存模型
在训练过程中我们可以直接使用manager的功能save进行存储,相关代码如下:

for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))

2、keras内置存储功能
keras有两种可以存储载入模型的内置函数

  • save 与load_model函数对
    具体api:
model.save() 或者 tf.keras.models.save_model()
tf.keras.models.load_model()

保存模型:

model.save(checkpoint_dir)

默认存储结果会有三个:

assets  saved_model.pb  variables

也可以直接指定存储为HDF5的格式

model.save(checkpoint_dir + '/' + 'model.h5')

载入模型:

recover_model = keras.models.load_model(checkpoint_dir)

如果是HDF5格式文件:

recover_model = keras.models.load_model(checkpoint_dir+ '/' + 'model.h5')

载入模型后生成新的对象recover_model,会复制原来model的所有功能。后续的训练测试使用recover_model

  • keras.Model 内置的save_weights与load_weights、get_weights与set_weights。
    其中常用的是save_weights与load_weights。
    save_weights可以有两种存储方式,tensorflow格式与h5格式。默认为使用tensorflow方式也是类似于检查点的方式进行存储。
    具体操作如下:
model.save_weights(path=checkpoint_dir)

存储为h5

model.save_weights(path=checkpoint_dir+/model.h5’,save_format='h5')

载入方式:
非h5文件:

model.load_weights(path=checkpoint_dir)

h5文件的载入:

model.load_weights(path=checkpoint_dir+'/model.h5')

以上就是训练过程常用的模型存储与加载的方式。可以看到tf2.0简化了tf1.0中的许多操作,对于用户来说已经是非常友好。拥抱pytorch的同学们可以再回来继续当tfboys。但是模型的部署怎么搞?以上几种办法,都需要搭建原始的结构,然后载入权重。这和环境上部署毛都不沾。
接下来介绍一个保存环境部署的模型的方法。具体应用在我的ner项目中已经体现,具体文件是run_pb.py。只有简单的四行代码,就可以载入模型。

首先看看怎么保存模型:
可以先恢复检查点,但是忽略优化器之类的权重。这里参考我ner项目中的infer代码。可用重新建一个图,不包含任何优化器节点。
然后恢复模型:

ner = ner_model(config,training=False)
#从训练的检查点恢复权重
ckpt = tf.train.Checkpoint(ner=ner)
latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir + 'trains')
#添加expect_partial()忽略优化器相关节点
status = ckpt.restore(latest_ckpt).expect_partial()

恢复之后,保存模型:

tf.saved_model.save(ner, checkpoint_dir + 'infers/')

经过这一步,我们可以看到,在checkpoint_dir + 'infers/'目录下有:

assets  saved_model.pb  variables

一个pb文件,两个检查点目录,这两个目录里面东西的作用,目前未知。

然后参看run_pb.py中代码。仅仅只有四行代码,我们就可以部署训练号的模型到环境上。

评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值