简介
在上一节中,我们学习了基本的tensroflow的操作流程,特别值得注意的是,会话在2.0中被替换成了function。但是值得注意的是,目前我们甚至还不会如何保存和加载已经训练好的模型,因此本次我们主要学习如何保存训练好的模型以及如何加载它。而方法则分为自动保存和手动保存两种。
模型保存的两种基本格式
在tensorflow保存模型的时候,分为两种形式,一种是只保存权重,它的保存形式是.ckpt。另一种则是保存整个训练的模型,它的保存形式是.h5。
第一种方法:在训练的时候自动保存参数
因为我们的模型继承自父类model,而在model类当中,我们有一个一个叫做fit的方法,这个方法可以Trains the model for a fixed number of epochs。在这个方法中我们可以传入一个叫做callbacks的对象。而这个传入的对象就会帮助我们进行自动参数的保存。下面我会详细介绍Modelcheckpoint这个类。
Class ModelCheckpoint
全称: tf.keras.callbacks.ModelCheckpoint
属性
- filepath:字符串,存储模型文件的路径
- monitor:我们所监视的变量(quantity to monitor)
- verbose:verbosity mode, 0 or 1
- save_best_only:if save_best_only=True, the latest best model according to the quantity monitored will not be overwritten.
- mode:one of {auto, min, max}.当参数save_best_only为真时,我们必须在min和max之中选择一个。loss对应min,accuracy对应max。当我们选择auto的时候,会根据我们变量的命名自动决定。
- save_weights_only:为真的时候,只存储权重,反之则保存全图。
- save_freq:字符串’epoch’或者整数值。当选择’epoch’的时候,我们每一代都会保存模型。默认是’epoch’。
- **kwargs: Additional arguments for backwards compatibility. Possible key is period.
tips:i f filepath is weights.{epoch:02d}-{val_loss:.2f}.hdf5, then the model checkpoints will be saved with the epoch number and the validation loss in the filename.
默认初始化如下:
__init__(
filepath,
monitor='val_loss',
verbose=0,
save_best_only=False,
save_weights_only=False,
mode='auto'