checkpoint_path = “training_1/cp.ckpt”
checkpoint_dir = os.path.dirname(checkpoint_path)
创建一个保存模型权重的回调
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True,
verbose=1)
使用新的回调训练模型
model.fit(train_images,
train_labels,
epochs=10,
validation_data=(test_images,test_labels),
callbacks=[cp_callback]) # 通过回调训练
下面是选用最优函数的参数
keras.callbacks.ModelCheckpoint(filepath,monitor=‘val_loss’,verbose=0,save_best_only=Ture, save_weights_only=False, mode=‘auto’, period=1)
参数说明:
filename:字符串,保存模型的路径
monitor:需要监视的值
verbose:信息展示模式,0或1(checkpoint的保存信息,类似Epoch 00001: saving model to …)
save_best_only:当设置为True时,监测值有改进时才会保存当前的模型( the latest best model according to the quantity monitored will not be overwritten)
mode:‘auto’,‘min’,‘max’之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当监测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。
save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)
period:CheckPoint之间的间隔的epoch数
注:如果用的是save_weights_only=True,在用load_weights导入参数前,需要先跑一遍模型,不然会报错。
创建回调以在验证损失达到特定值后提前停止训练。
stop_early = tf.keras.callbacks.EarlyStopping(monitor=‘val_loss’, patience=5)
由于 callbacks=[cp_callback] 是支持多个列表,因此可以把这个也加入模型训练中的参数中