深度学习之Early Stopping:TensorFlow中的早停机制示例及说明

深度学习之Early Stopping:TensorFlow中的早停机制示例及说明

一、简单说明

  在模型训练中,早停机制是一种用于避免模型在训练时出现过拟合(过拟合的检测方法参见另一篇文章《深度学习之Over Fitting:神经网络过拟合的检测方法及应对策略》)的方法。它通过跟踪指定的参考指标来实现,当指标连续几个epoch不再提高时,训练就会停止,以避免模型过度拟合,提高模型的泛化能力,通过及时停止训练,也可以减少训练时间和计算资源的消耗。TensorFlow 中提供了EarlyStopping回调函数来实现早停功能。

二、代码说明

  实际使用中,我们在模型结构设计后,通过调用tf.keras.callbacks.EarlyStopping来设定提前停止条件,如下:

# 设定提前停止条件
es_callback = tf.keras.callbacks.EarlyStopping(monitor='loss',
                                               patience=3,
                                               start_from_epoch=0)

其中,monitor参数用来设定早停的参考指标,在本例中我们选择了训练损失loss,也可以选择训练准确性accuracy,验证损失val_loss,或验证准确性val_accuracypatience参数设定为3,在本例中则意味着如果连续三轮训练损失都没有提高(即损失变大),则停止训练;参数start_from_epoch,用于设定上述条件从第几轮开始起作用。
设定好上述提前停止条件后,还需要在训练时通过callbacks参数传递给fit方法,如下:

# 训练模型
history = model.fit(x_train,y_train,
                    epochs=300,
                    batch_size=20,
                    validation_data=(x_test,y_test),
                    callbacks=es_callback)

  如下图,是某次训练结束前几轮的训练数据。从图中可见,在连续三轮(epoch 55~epoch 57)的训练损失loss都没有提高,即均大于epoch 54的训练损失loss,满足了设定的早停条件,训练终止。在这里插入图片描述

三、完整代码

#!/user/bin/env python3
# -*- coding : utf-8 -*-
import tensorflow as tf

def main():
    # 加载MNIST数据库
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    
    # 设计网络结构
    model = tf.keras.models.Sequential([
            tf.keras.layers.Flatten(input_shape=(28,28)),# 因为输入数据是图片(28*28),所以此处需要加一个Flatten层将图片数据展平
            tf.keras.layers.Dense(100),
            tf.keras.layers.Dense(10,activation="sigmoid")
            ],name="FNN")
    
    # 定义代价函数
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    
    # 设定模型训练配置
    model.compile(optimizer='adam',loss=loss_fn,metrics=['accuracy'])

    # 设定提前停止条件
    es_callback = tf.keras.callbacks.EarlyStopping(monitor='loss',
                                                   patience=3,
                                                   start_from_epoch=0)

    # 训练模型
    history = model.fit(x_train,y_train,
                        epochs=300,
                        batch_size=20,
                        validation_data=(x_test,y_test),
                        callbacks=es_callback)
    
    # 保存模型
    # tf.keras.models.save_model(model,filepath='./model/ann')
    
if __name__=='__main__':
  main()

四、参考资料

[1] TensorFlow API 文档:早停回调函数

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值