回调函数(callback)是在调用fit 时传入模型的一个对象(即实现特定方法的类实例),它在训练过程中的不同时间点都会被模型调用。它可以访问关于模型状态与性能的所有可用数据,还可以采取行动:中断训练、保存模型、加载一组不同的权重或改变模型的状态。
keras中也提供了丰富的回调API,我们可以根据需求自定义相关的对象。
文章目录
BaseLogger
顾名思义,基础日志,用于记录每个epoch的平均metrics.
该回调函数在每个模型中都会被自动调用
CSVLogger
记录每个epoch的结果到csv文件中
csvlogger = tf.keras.callbacks.CSVLogger(filename, separator=',', append=False)
参数 | 注解 |
---|---|
fiename | 保存的csv文件名,如run/log.csv |
separator | 字符串,csv分隔符 |
append | 默认为False,为True时csv文件如果存在则继续写入,为False时总是覆盖csv文件 |
结果如下:包括训练集和验证集合的loss,以及学习率
EarlyStopping
当metric停止提升时,停止训练
earlystoppoing = tf.keras.callbacks.EarlyStopping(
monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto',
baseline=None, restore_best_weights=False)
参数 | 注解 |
---|---|
monitor | 监视标准 |
min_delta | 监控标准在训练过程中允许的最小改变量,即小于此值后便认为性能没有提升 |
patience | 相比上一个epoch训练监控标准没有提升(小于min_delta),则经过patience个epoch后停止训练(还是很形象的patience 耐心) |
verbose | 信息展示模式 |
mode | ‘auto’,‘min’,‘max’之一,与monitor对应。比如loss对应min,acc对应max;auto根据monitor的名称自动定义 |
baseline | 监控标准的基准线,当训练过程相对基准没有提升则停止训练 |
restore_best_weights | 布尔型,Ture,重载最好的监控量的epoch对应的weight,False,最后一步的权重 |
LearningRateScheduler
定义学习率日程表,即自定义不同epcoh的学习率
# 前十组为0.001 后面指数减少
def scheduler(epoch):
if epoch < 10:
return 0.001
else:
return 0.001 * tf.math.exp(0.1 * (10 - epoch))
learningreatescheduler = tf.keras.callbacks.LearningRateScheduler(scheduler)
ModelCheckpoint
按照定义以一定频率存储模型或者权值文件
modelcheckpoint=tf.keras.callbacks.ModelCheckpoint(
filepath, monitor='val_loss', verbose