以xDeepFM为例
保存模型参数
from deepctr.models import xDeepFM
model = xDeepFM(...)
model.compile(...)
model.fit(...)
# save_weights
model.save_weights('... .h5')
读取h5模型
model = xDeepFM(...)
# load_weights
model.load_weights('... .h5')
由于h5文件保存的是模型参数,因此模型结构需要自己手动构建,如果需要保存模型结构,可以使用json文件保存模型结构。
保存最优模型权重
上述方法保存的是所有epoch训练后的最后一次结果,但不一定是最优值,因此可以使用checkpoint保存最优权重。
model = xDeepFM(...)
model.compile(...)
# checkpoint
checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath='... .h5', monitor='val_loss', verbose=1, save_best_only=True, mode = 'min')
callback_list = [checkpoint]
model.fit(...,callbacks = callback_list )
ModelCheckpoint的参数中,filepath即保存的h5文件,monitor是监控的指标,一般为val_loss,val_acc等,mode指定保存最大值还是最小值,具体解释可参考官方文档。