背景
在未知训练迭代次数的情况下,动态确定每个训练需要的迭代次数,需要设置自定义回调函数,配置acc或mse损失到达一定程度后停止训练
实例
import tensorflow as tf
class earlyStop(tf.keras.callbacks.Callback):
def __init__(self,mode,acc_threshold=0.95,loss_threshold=0.025):
super().__init__()
self.mode = mode
self.acc_threshold = acc_threshold
self.loss_threshold = loss_threshold
def on_epoch_end(self, epoch, logs=None):
if self.mode=='a':#二分类
if float(logs['acc'][-1])>self.acc_threshold:
self.model.stop_training = True
print("训练Early Stopping 迭代次数共计",epoch)
elif self.mode=='b':#值回归
if float(logs['loss'][-1])<self.loss_threshold:
self.model.stop_training = True
print("训练Early Stopping 迭代次数共计",epoch)
关键点
1. 终止训练
self.model.stop_training = True
2. 回调函数调用
history = model.fit(features,
to_categorical(labels,2),
epochs=epochs, # 迭代次数
batch_size=batch_size,
verbose=2, # 该方法训练不动 异常占比过小
# validation_split=0.1 # 没必要,数据集过度不平衡,参考价值不大
callbacks=[earlyStop(mode,acc_threshold,loss_threshold)]
)
总结
对于迭代次数不定的场景、需要按照条件停止训练的场景,都可以自己编写回调函数,不用官方提供的earlystopping!自力更生,慢却也最快。