探索一下BERT模型保存和加载方式,基于源码。
保存
所谓“保存模型”一般是指保存ckpt和saved_model两种格式的模型。
ckpt方式与session.run模型下保存模型格式一样(在sess.run模式下,通常使用saver = tf.train.Saver()和saver.save()保存模型),这种模型文件需要原始模型代码才能运行,一般用于训练中保存/加载权重。
saved_model格式是一种轻量化的模型,不仅包含权重值,还包含计算。它不需要原始模型构建代码就可以运行,因此,对共享和部署(使用 TFLite、TensorFlow.js、TensorFlow Serving 或 TensorFlow Hub)非常有用。
ckpt方式下一共会保存4个文件:
model.ckpt-xxxxx.data-00000-of-00001: 保存当前参数值。比如网络的权值,偏置,操作等等。
model.ckpt.index :保存当前参数名。二进制或者其他格式,不可直接查看 。
model.ckpt.meta:某个ckpt的meta数据 二进制 或者其他格式,不可直接查看,保存了TensorFlow计算图的结构信息。
checkpoint:文本文件,记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。
1.默认checkpoint的保存行为
每10分钟(600 秒)写入一个checkpoint,还会在train方法开始(第一次迭代)和完成(最后一次迭代)时写入一个checkpoint;只在目录中保留5个最近写入的checkpoint;
2.修改默认checkpoint的保存行为
创建一个RunConfig对象来定义所需的时间安排;在实例化Estimator时,将该RunConfig对象传递给Estimator的config参数;
my_checkpointing_config = tf.estimator.RunConfig(
save_checkpoints_secs = 20*60, # 每20分钟保存一次checkpoint
keep_checkpoint_max = 10, # 保存10个最近的checkpoints
)
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris',
config=my_checkpointing_config)
加载
1.Estimator 通过运行 model_fn() 构建模型图。
2.Estimator 根据最近写入的检查点中存储的数据来初始化新模型的权重。
一旦存在检查点,TensorFlow 就会在您每次调用 train()、evaluate() 或 predict() 时加载模型。