tensorflow2.0中模型的加载更加便捷。
我在github上新建了一个有关ner的项目,其中有对tensorflow2.0的api的一些详细使用。NER
想了解更多tensorflow2.0中模型存储加载方法,可以直接到其官方网站tf2.0.
我们这里说一下几个保存权重的方法:
假如当前建立的模型代码如下:
import tensorflow as tf
from tensorflow import keras
def get_model():
# Create a simple model.
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mean_squared_error')
return model
model = get_model()
opt = tf.keras.optimizers.Adam(0.1)
checkpoint_dir="./checkpoint"
1、保存检查点
具体api如下:
tf.train.Checkpoint
tf.train.CheckpointManager
使用以上两个api就可以保存训练中所有的权重。具体操作如下:
首先创建检查点
ckpt = tf.train.Checkpoint(optimizer=opt,model=model)
manager = tf.train.CheckpointManager(ckpt,
checkpoint_dir,
max_to_keep=3)
具体参数含义可以直接help查看api中参数的解释。
创建完检查点后,如果存在旧模型,就需要从旧模型中恢复权重。操作如下:
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
print("Restored from {}".format(manager.latest_checkpoint))
else:
print("Initializing from scratch.")
然后再看看如何保存模型
在训练过程中我们可以直接使用manager的功能save进行存储,相关代码如下:
for _ in range(50):
example = next(iterator)
loss = train_step(net, example, opt)
ckpt.step.assign_add(1)
if int(ckpt.step) % 10 == 0:
save_path = manager.save()
print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
2、keras内置存储功能
keras有两种可以存储载入模型的内置函数
- save 与load_model函数对
具体api:
model.save() 或者 tf.keras.models.save_model()
tf.keras.models.load_model()
保存模型:
model.save(checkpoint_dir)
默认存储结果会有三个:
assets saved_model.pb variables
也可以直接指定存储为HDF5的格式
model.save(checkpoint_dir + '/' + 'model.h5')
载入模型:
recover_model = keras.models.load_model(checkpoint_dir)
如果是HDF5格式文件:
recover_model = keras.models.load_model(checkpoint_dir+ '/' + 'model.h5')
载入模型后生成新的对象recover_model,会复制原来model的所有功能。后续的训练测试使用recover_model
- keras.Model 内置的save_weights与load_weights、get_weights与set_weights。
其中常用的是save_weights与load_weights。
save_weights可以有两种存储方式,tensorflow格式与h5格式。默认为使用tensorflow方式也是类似于检查点的方式进行存储。
具体操作如下:
model.save_weights(path=checkpoint_dir)
存储为h5
model.save_weights(path=checkpoint_dir+‘/model.h5’,save_format='h5')
载入方式:
非h5文件:
model.load_weights(path=checkpoint_dir)
h5文件的载入:
model.load_weights(path=checkpoint_dir+'/model.h5')
以上就是训练过程常用的模型存储与加载的方式。可以看到tf2.0简化了tf1.0中的许多操作,对于用户来说已经是非常友好。拥抱pytorch的同学们可以再回来继续当tfboys。但是模型的部署怎么搞?以上几种办法,都需要搭建原始的结构,然后载入权重。这和环境上部署毛都不沾。
接下来介绍一个保存环境部署的模型的方法。具体应用在我的ner项目中已经体现,具体文件是run_pb.py。只有简单的四行代码,就可以载入模型。
首先看看怎么保存模型:
可以先恢复检查点,但是忽略优化器之类的权重。这里参考我ner项目中的infer代码。可用重新建一个图,不包含任何优化器节点。
然后恢复模型:
ner = ner_model(config,training=False)
#从训练的检查点恢复权重
ckpt = tf.train.Checkpoint(ner=ner)
latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir + 'trains')
#添加expect_partial()忽略优化器相关节点
status = ckpt.restore(latest_ckpt).expect_partial()
恢复之后,保存模型:
tf.saved_model.save(ner, checkpoint_dir + 'infers/')
经过这一步,我们可以看到,在checkpoint_dir + 'infers/'目录下有:
assets saved_model.pb variables
一个pb文件,两个检查点目录,这两个目录里面东西的作用,目前未知。
然后参看run_pb.py中代码。仅仅只有四行代码,我们就可以部署训练号的模型到环境上。