深度学习之Early Stopping:TensorFlow中的早停机制示例及说明

深度学习之Early Stopping:TensorFlow中的早停机制示例及说明

一、简单说明

  在模型训练中,早停机制是一种用于避免模型在训练时出现过拟合(过拟合的检测方法参见另一篇文章《深度学习之Over Fitting:神经网络过拟合的检测方法及应对策略》)的方法。它通过跟踪指定的参考指标来实现,当指标连续几个epoch不再提高时,训练就会停止,以避免模型过度拟合,提高模型的泛化能力,通过及时停止训练,也可以减少训练时间和计算资源的消耗。TensorFlow 中提供了EarlyStopping回调函数来实现早停功能。

二、代码说明

  实际使用中,我们在模型结构设计后,通过调用tf.keras.callbacks.EarlyStopping来设定提前停止条件,如下:

# 设定提前停止条件
es_callback = tf.keras.callbacks.EarlyStopping(monitor='loss',
                                               patience=3,
                                               start_from_epoch=0)

其中,monitor参数用来设定早停的参考指标,在本例中我们选择了训练损失loss,也可以选择训练准确性accuracy,验证损失val_loss,或验证准确性val_accuracypatience参数设定为3,在本例中则意味着如果连续三轮训练损失都没有提高(即损失变大),则停止训练;参数start_from_epoch,用于设定上述条件从第几轮开始起作用。
设定好上述提前停止条件后,还需要在训练时通过callbacks参数传递给fit方法,如下:

# 训练模型
history = model.fit(x_train,y_train,
                    epochs=300,
                    batch_size=20,
                    validation_data=(x_test,y_test),
                    callbacks=es_callback)

  如下图,是某次训练结束前几轮的训练数据。从图中可见,在连续三轮(epoch 55~epoch 57)的训练损失loss都没有提高,即均大于epoch 54的训练损失loss,满足了设定的早停条件,训练终止。在这里插入图片描述

三、完整代码

#!/user/bin/env python3
# -*- coding : utf-8 -*-
import tensorflow as tf

def main():
    # 加载MNIST数据库
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    
    # 设计网络结构
    model = tf.keras.models.Sequential([
            tf.keras.layers.Flatten(input_shape=(28,28)),# 因为输入数据是图片(28*28),所以此处需要加一个Flatten层将图片数据展平
            tf.keras.layers.Dense(100),
            tf.keras.layers.Dense(10,activation="sigmoid")
            ],name="FNN")
    
    # 定义代价函数
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    
    # 设定模型训练配置
    model.compile(optimizer='adam',loss=loss_fn,metrics=['accuracy'])

    # 设定提前停止条件
    es_callback = tf.keras.callbacks.EarlyStopping(monitor='loss',
                                                   patience=3,
                                                   start_from_epoch=0)

    # 训练模型
    history = model.fit(x_train,y_train,
                        epochs=300,
                        batch_size=20,
                        validation_data=(x_test,y_test),
                        callbacks=es_callback)
    
    # 保存模型
    # tf.keras.models.save_model(model,filepath='./model/ann')
    
if __name__=='__main__':
  main()

四、参考资料

[1] TensorFlow API 文档:早停回调函数

  • 5
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
以下是一个简单的利用深度学习进行图像识别的Python代码示例: ```python # 导入所需的库 import tensorflow as tf from tensorflow.keras.datasets import mnist from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D from tensorflow.keras.callbacks import EarlyStopping # 加载 MNIST 数据集 (x_train, y_train), (x_test, y_test) = mnist.load_data() # 将像素值缩放到 0~1 之间 x_train = x_train / 255.0 x_test = x_test / 255.0 # 将标签转换为 one-hot 编码 y_train = tf.keras.utils.to_categorical(y_train, 10) y_test = tf.keras.utils.to_categorical(y_test, 10) # 构建模型 model = Sequential() model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1))) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Dropout(0.25)) model.add(Flatten()) model.add(Dense(128, activation='relu')) model.add(Dropout(0.5)) model.add(Dense(10, activation='softmax')) # 编译模型 model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) # 定义回调函数 early_stopping = EarlyStopping(monitor='val_loss', patience=3) # 训练模型 model.fit(x_train.reshape(-1, 28, 28, 1), y_train, batch_size=128, epochs=20, verbose=1, validation_split=0.2, callbacks=[early_stopping]) # 评估模型 score = model.evaluate(x_test.reshape(-1, 28, 28, 1), y_test, verbose=0) print('Test loss:', score[0]) print('Test accuracy:', score[1]) ``` 该代码使用 TensorFlow 库来创建一个卷积神经网络模型,该模型可以识别手写数字。数据集使用的是 MNIST 数据集,该数据集包含了一系列手写数字图像。在训练时,我们用回调函数来避免模型过拟合。在训练完成后,我们使用测试集对模型进行评估,并输出测试集的损失和准确率。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值