深度学习之Early Stopping:TensorFlow中的早停机制示例及说明
一、简单说明
在模型训练中,早停机制是一种用于避免模型在训练时出现过拟合(过拟合的检测方法参见另一篇文章《深度学习之Over Fitting:神经网络过拟合的检测方法及应对策略》)的方法。它通过跟踪指定的参考指标来实现,当指标连续几个epoch不再提高时,训练就会停止,以避免模型过度拟合,提高模型的泛化能力,通过及时停止训练,也可以减少训练时间和计算资源的消耗。TensorFlow 中提供了EarlyStopping
回调函数来实现早停功能。
二、代码说明
实际使用中,我们在模型结构设计后,通过调用tf.keras.callbacks.EarlyStopping
来设定提前停止条件,如下:
# 设定提前停止条件
es_callback = tf.keras.callbacks.EarlyStopping(monitor='loss',
patience=3,
start_from_epoch=0)
其中,monitor
参数用来设定早停的参考指标,在本例中我们选择了训练损失loss
,也可以选择训练准确性accuracy
,验证损失val_loss
,或验证准确性val_accuracy
;patience
参数设定为3
,在本例中则意味着如果连续三轮训练损失都没有提高(即损失变大),则停止训练;参数start_from_epoch
,用于设定上述条件从第几轮开始起作用。
设定好上述提前停止条件后,还需要在训练时通过callbacks
参数传递给fit
方法,如下:
# 训练模型
history = model.fit(x_train,y_train,
epochs=300,
batch_size=20,
validation_data=(x_test,y_test),
callbacks=es_callback)
如下图,是某次训练结束前几轮的训练数据。从图中可见,在连续三轮(epoch 55~epoch 57)的训练损失loss都没有提高,即均大于epoch 54的训练损失loss,满足了设定的早停条件,训练终止。
三、完整代码
#!/user/bin/env python3
# -*- coding : utf-8 -*-
import tensorflow as tf
def main():
# 加载MNIST数据库
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# 设计网络结构
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28,28)),# 因为输入数据是图片(28*28),所以此处需要加一个Flatten层将图片数据展平
tf.keras.layers.Dense(100),
tf.keras.layers.Dense(10,activation="sigmoid")
],name="FNN")
# 定义代价函数
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# 设定模型训练配置
model.compile(optimizer='adam',loss=loss_fn,metrics=['accuracy'])
# 设定提前停止条件
es_callback = tf.keras.callbacks.EarlyStopping(monitor='loss',
patience=3,
start_from_epoch=0)
# 训练模型
history = model.fit(x_train,y_train,
epochs=300,
batch_size=20,
validation_data=(x_test,y_test),
callbacks=es_callback)
# 保存模型
# tf.keras.models.save_model(model,filepath='./model/ann')
if __name__=='__main__':
main()