T5: Sneaker Brand Recognition

>- **🍨 This post is a study-log entry for the [🔗 365-Day Deep Learning Training Camp](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg)**
>- **🍖 Original author: [K同学啊](https://mtyjkh.blog.csdn.net/)**

1. Getting the Model to Run

1.1 Environment

IDE: PyCharm Community Edition

Language: Python 3.9.19

Deep learning framework: TensorFlow 2.9.0

1.2 Workflow

  1. Import the required libraries

    • Import the Python standard-library and third-party modules used for data handling, model building, and training.
  2. Set environment variables

    • Set an environment variable to work around an OpenMP error that can occur with concurrent code on some systems.
  3. Configure the GPU

    • Check whether a GPU is available. If so, enable on-demand memory growth and make it the visible device; otherwise print a message that the CPU will be used.
  4. Load the dataset

    • Set the local directory path of the dataset and count the total number of image files under it.
  5. Display a sample image

    • Load and display one sample image to confirm the image paths are correct.
  6. Set image-processing parameters

    • Set the batch size (batch_size) and image dimensions (img_height, img_width).
  7. Load the training and validation datasets

    • Load the training and validation datasets from the given directory, resizing images to the configured size and batching them.
  8. Print the class names

    • Get and print the names of all classes in the dataset.
  9. Visualize training images

    • Visualize one batch of training images with their class labels to confirm that data loading and labeling are correct.
  10. Print the batch shapes

    • Print the shapes of one batch of images and labels to confirm the batch size and image dimensions.
  11. Optimize data loading

    • Use caching, shuffling, and prefetching to speed up data loading.
  12. Define data augmentation

    • Define data augmentation, including random horizontal and vertical flips, random rotation, and random zoom, to increase the diversity of the training data.
  13. Build the model

    • Build a convolutional neural network with a data augmentation layer, a rescaling layer, convolutional layers, pooling layers, fully connected layers, and a Dropout layer.
  14. Explicitly build the model

    • Explicitly build the model to avoid "model not built" errors, then print the model summary to inspect the architecture and parameter counts.
  15. Set the learning-rate schedule and optimizer

    • Define a learning-rate schedule with an initial rate that decays exponentially, and pass the schedule to the optimizer.
  16. Compile the model

    • Compile the model with the optimizer, loss function, and evaluation metric.
  17. Define callbacks

    • Define a checkpoint callback that saves the model weights whenever validation accuracy reaches a new best, and an early-stopping callback that halts training when validation accuracy stops improving and restores the best weights.
  18. Train the model

    • Train the model on the training and validation data for the configured number of epochs, passing in the callbacks.
  19. Visualize training and validation

    • Plot the accuracy and loss curves for training and validation to show how model performance evolves over epochs.
  20. Summarize and analyze

    • Use the plots and evaluation metrics to summarize training behavior and the model's final performance.
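The exponential decay in step 15 can be sketched in plain Python. The constants below mirror the ones used in the full code in 1.3 (initial rate 0.001, decay_steps=1000, decay_rate=0.9, staircase mode):

```python
def exponential_decay(step, initial_lr=0.001, decay_steps=1000,
                      decay_rate=0.9, staircase=True):
    """Plain-Python sketch of tf.keras.optimizers.schedules.ExponentialDecay.

    With staircase=True the exponent is the integer number of completed
    decay periods, so the learning rate drops in discrete steps rather
    than decaying smoothly.
    """
    exponent = step // decay_steps if staircase else step / decay_steps
    return initial_lr * decay_rate ** exponent

# For the first 1000 steps the rate stays at 0.001; after each further
# 1000 steps it is multiplied by 0.9.
lr_at_start = exponential_decay(0)
lr_after_two_periods = exponential_decay(2500)  # 0.001 * 0.9**2
```

With staircase=False the same formula decays continuously, because the exponent becomes the fractional value step / decay_steps.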

1.3 Full Code

import os
import tensorflow as tf
import pathlib
import PIL.Image
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Import the required libraries: the OS interface (os), TensorFlow, path handling (pathlib), image handling (PIL), plotting (matplotlib), and the Keras modules used to build and train the model.

# Set an environment variable to avoid an OpenMP error
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# This works around an OpenMP error that can occur on some systems.

# GPU configuration
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    gpu0 = gpus[0]
    tf.config.experimental.set_memory_growth(gpu0, True)  # enable on-demand GPU memory growth
    tf.config.set_visible_devices([gpu0], "GPU")
    print("GPU available.")
else:
    print("GPU cannot be found, using CPU instead.")
# Check whether a GPU is available; if so, enable on-demand memory growth and make it the visible device, otherwise print that the CPU will be used.

data_dir = "D:/others/pycharm/pythonProject/T5"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*/*.jpg')))
print("Total number of images:", image_count)
# Set the dataset directory path and count all jpg images under it.

sample_images = list(data_dir.glob('train/nike/*.jpg'))
PIL.Image.open(str(sample_images[0])).show()
# Load the image files under the given path and open one sample image with PIL to confirm the path is correct.

batch_size = 32
img_height = 224
img_width = 224
# Set the image-processing parameters: batch size (batch_size) and image height and width (img_height, img_width).

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "D:/others/pycharm/pythonProject/T5/train/",
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "D:/others/pycharm/pythonProject/T5/train/",
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
# Load the training and validation datasets from the train directory, holding out 20% of the images for validation; resize the images to the configured size and batch them.

class_names = train_ds.class_names
print(class_names)
# Get and print the class names found in the dataset.

plt.figure(figsize=(20, 10))

for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)
        try:
            plt.imshow(images[i].numpy().astype("uint8"))
            plt.title(class_names[labels[i]])
            plt.axis("off")
        except Exception as e:
            print(f"Error displaying image {i}: {e}")

plt.show()  # make sure plt.show() is called so the figure is displayed
# Visualize one batch of training images in a grid, labeled with their class names.

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
# Print the shapes of one batch of images and labels to confirm the batch size and image dimensions.

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
# Speed up data loading with caching, shuffling, and prefetching.

# Data augmentation
data_augmentation = tf.keras.Sequential([
    layers.experimental.preprocessing.RandomFlip('horizontal_and_vertical'),
    layers.experimental.preprocessing.RandomRotation(0.2),
    layers.experimental.preprocessing.RandomZoom(0.2),
])
# Define the data augmentation: random horizontal/vertical flips, random rotation, and random zoom.

model = models.Sequential([
    data_augmentation,  # data augmentation layer
    layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=(img_height, img_width, 3)),  # rescale pixel values to [0, 1]
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(512, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(len(class_names))
])
# Build the CNN: data augmentation, rescaling, convolutional, pooling, fully connected, and Dropout layers.

model.build((None, img_height, img_width, 3))
# Explicitly build the model to avoid "model not built" errors.

model.summary()  # print the network architecture
# Print the model summary, showing the architecture and parameter counts.

# Set the initial learning rate
initial_learning_rate = 0.001
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=1000,
    decay_rate=0.9,
    staircase=True
)
# Define the learning-rate schedule: start from the initial rate and decay it exponentially.

# Feed the exponentially decaying learning rate into the optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
# Use the Adam optimizer with the learning-rate schedule passed as its learning rate.

model.compile(optimizer=optimizer,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
# Compile the model with the optimizer, loss function, and evaluation metric.

epochs = 50

# Save the best model weights
checkpointer = ModelCheckpoint('best_model.h5',
                               monitor='val_accuracy',
                               verbose=1,
                               save_best_only=True,
                               save_weights_only=True)
# Checkpoint callback: save the model weights whenever validation accuracy reaches a new best.

# Set up early stopping
earlystopper = EarlyStopping(monitor='val_accuracy',
                             min_delta=0.001,
                             patience=30,  # generous patience value
                             verbose=1,
                             restore_best_weights=True)
# Early-stopping callback: stop training when validation accuracy stops improving and restore the best weights.

history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=epochs,
                    callbacks=[checkpointer, earlystopper])
# Train the model on the training and validation datasets for the configured number of epochs, passing in the callbacks.

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(len(loss))

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
# Plot the training and validation accuracy and loss curves to show how performance evolves over epochs.

Full output:

D:\others\anaconda\envs\deep_learning_env\python.exe D:\others\pycharm\pythonProject\T5_Brand_of_sport_recognition.py 
GPU cannot be found, using CPU instead.
Total number of images: 502
2024-07-07 16:17:57.783825: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-07-07 16:17:57.784783: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
Found 502 files belonging to 2 classes.
Found 502 files belonging to 2 classes.
['adidas', 'nike']
(32, 224, 224, 3)
(32,)
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 sequential (Sequential)     (None, 224, 224, 3)       0         
                                                                 
 rescaling (Rescaling)       (None, 224, 224, 3)       0         
                                                                 
 vgg16 (Functional)          (None, 7, 7, 512)         14714688  
                                                                 
 flatten (Flatten)           (None, 25088)             0         
                                                                 
 dense (Dense)               (None, 128)               3211392   
                                                                 
 dense_1 (Dense)             (None, 2)                 258       
                                                                 
=================================================================
Total params: 17,926,338
Trainable params: 3,211,650
Non-trainable params: 14,714,688
_________________________________________________________________
Epoch 1/50
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
16/16 [==============================] - ETA: 0s - loss: 3.1690 - accuracy: 0.5359
Epoch 1: val_accuracy improved from -inf to 0.50199, saving model to best_model.h5
16/16 [==============================] - 76s 5s/step - loss: 3.1690 - accuracy: 0.5359 - val_loss: 1.6623 - val_accuracy: 0.5020
Epoch 2/50
16/16 [==============================] - ETA: 0s - loss: 0.8153 - accuracy: 0.6554
Epoch 2: val_accuracy improved from 0.50199 to 0.64542, saving model to best_model.h5
16/16 [==============================] - 74s 5s/step - loss: 0.8153 - accuracy: 0.6554 - val_loss: 0.7469 - val_accuracy: 0.6454
Epoch 3/50
16/16 [==============================] - ETA: 0s - loss: 0.4668 - accuracy: 0.7928
Epoch 3: val_accuracy improved from 0.64542 to 0.69721, saving model to best_model.h5
16/16 [==============================] - 73s 5s/step - loss: 0.4668 - accuracy: 0.7928 - val_loss: 0.6094 - val_accuracy: 0.6972
Epoch 4/50
16/16 [==============================] - ETA: 0s - loss: 0.3502 - accuracy: 0.8625
Epoch 4: val_accuracy improved from 0.69721 to 0.81275, saving model to best_model.h5
16/16 [==============================] - 74s 5s/step - loss: 0.3502 - accuracy: 0.8625 - val_loss: 0.4189 - val_accuracy: 0.8127
Epoch 5/50
16/16 [==============================] - ETA: 0s - loss: 0.3005 - accuracy: 0.8785
Epoch 5: val_accuracy improved from 0.81275 to 0.85060, saving model to best_model.h5
16/16 [==============================] - 74s 5s/step - loss: 0.3005 - accuracy: 0.8785 - val_loss: 0.3549 - val_accuracy: 0.8506
Epoch 6/50
16/16 [==============================] - ETA: 0s - loss: 0.2425 - accuracy: 0.9143
Epoch 6: val_accuracy did not improve from 0.85060
16/16 [==============================] - 74s 5s/step - loss: 0.2425 - accuracy: 0.9143 - val_loss: 0.3486 - val_accuracy: 0.8327
Epoch 7/50
16/16 [==============================] - ETA: 0s - loss: 0.2191 - accuracy: 0.9263
Epoch 7: val_accuracy improved from 0.85060 to 0.86653, saving model to best_model.h5
16/16 [==============================] - 73s 5s/step - loss: 0.2191 - accuracy: 0.9263 - val_loss: 0.3106 - val_accuracy: 0.8665
Epoch 8/50
16/16 [==============================] - ETA: 0s - loss: 0.2055 - accuracy: 0.9382
Epoch 8: val_accuracy did not improve from 0.86653
16/16 [==============================] - 74s 5s/step - loss: 0.2055 - accuracy: 0.9382 - val_loss: 0.3532 - val_accuracy: 0.8267
Epoch 9/50
16/16 [==============================] - ETA: 0s - loss: 0.2486 - accuracy: 0.9044
Epoch 9: val_accuracy improved from 0.86653 to 0.87251, saving model to best_model.h5
16/16 [==============================] - 75s 5s/step - loss: 0.2486 - accuracy: 0.9044 - val_loss: 0.3037 - val_accuracy: 0.8725
Epoch 10/50
16/16 [==============================] - ETA: 0s - loss: 0.1857 - accuracy: 0.9402
Epoch 10: val_accuracy improved from 0.87251 to 0.87649, saving model to best_model.h5
16/16 [==============================] - 79s 5s/step - loss: 0.1857 - accuracy: 0.9402 - val_loss: 0.2883 - val_accuracy: 0.8765
Epoch 11/50
16/16 [==============================] - ETA: 0s - loss: 0.1697 - accuracy: 0.9422
Epoch 11: val_accuracy did not improve from 0.87649
16/16 [==============================] - 74s 5s/step - loss: 0.1697 - accuracy: 0.9422 - val_loss: 0.3102 - val_accuracy: 0.8566
Epoch 12/50
16/16 [==============================] - ETA: 0s - loss: 0.1545 - accuracy: 0.9522
Epoch 12: val_accuracy did not improve from 0.87649
16/16 [==============================] - 76s 5s/step - loss: 0.1545 - accuracy: 0.9522 - val_loss: 0.3104 - val_accuracy: 0.8685
Epoch 13/50
16/16 [==============================] - ETA: 0s - loss: 0.1471 - accuracy: 0.9641
Epoch 13: val_accuracy did not improve from 0.87649
16/16 [==============================] - 78s 5s/step - loss: 0.1471 - accuracy: 0.9641 - val_loss: 0.2829 - val_accuracy: 0.8705
Epoch 14/50
16/16 [==============================] - ETA: 0s - loss: 0.1438 - accuracy: 0.9602
Epoch 14: val_accuracy improved from 0.87649 to 0.89442, saving model to best_model.h5
16/16 [==============================] - 74s 5s/step - loss: 0.1438 - accuracy: 0.9602 - val_loss: 0.2565 - val_accuracy: 0.8944
Epoch 15/50
16/16 [==============================] - ETA: 0s - loss: 0.1189 - accuracy: 0.9721
Epoch 15: val_accuracy did not improve from 0.89442
16/16 [==============================] - 75s 5s/step - loss: 0.1189 - accuracy: 0.9721 - val_loss: 0.2499 - val_accuracy: 0.8865
Epoch 16/50
16/16 [==============================] - ETA: 0s - loss: 0.1098 - accuracy: 0.9741
Epoch 16: val_accuracy did not improve from 0.89442
16/16 [==============================] - 81s 5s/step - loss: 0.1098 - accuracy: 0.9741 - val_loss: 0.2466 - val_accuracy: 0.8884
Epoch 17/50
16/16 [==============================] - ETA: 0s - loss: 0.1156 - accuracy: 0.9781
Epoch 17: val_accuracy did not improve from 0.89442
16/16 [==============================] - 78s 5s/step - loss: 0.1156 - accuracy: 0.9781 - val_loss: 0.2794 - val_accuracy: 0.8745
Epoch 18/50
16/16 [==============================] - ETA: 0s - loss: 0.0900 - accuracy: 0.9861
Epoch 18: val_accuracy did not improve from 0.89442
16/16 [==============================] - 75s 5s/step - loss: 0.0900 - accuracy: 0.9861 - val_loss: 0.2404 - val_accuracy: 0.8845
Epoch 19/50
16/16 [==============================] - ETA: 0s - loss: 0.0839 - accuracy: 0.9781
Epoch 19: val_accuracy improved from 0.89442 to 0.90637, saving model to best_model.h5
16/16 [==============================] - 77s 5s/step - loss: 0.0839 - accuracy: 0.9781 - val_loss: 0.2263 - val_accuracy: 0.9064
Epoch 20/50
16/16 [==============================] - ETA: 0s - loss: 0.0802 - accuracy: 0.9861
Epoch 20: val_accuracy did not improve from 0.90637
16/16 [==============================] - 76s 5s/step - loss: 0.0802 - accuracy: 0.9861 - val_loss: 0.2437 - val_accuracy: 0.8964
Epoch 21/50
16/16 [==============================] - ETA: 0s - loss: 0.0674 - accuracy: 0.9880
Epoch 21: val_accuracy did not improve from 0.90637
16/16 [==============================] - 79s 5s/step - loss: 0.0674 - accuracy: 0.9880 - val_loss: 0.2183 - val_accuracy: 0.9064
Epoch 22/50
16/16 [==============================] - ETA: 0s - loss: 0.0661 - accuracy: 0.9900
Epoch 22: val_accuracy did not improve from 0.90637
16/16 [==============================] - 78s 5s/step - loss: 0.0661 - accuracy: 0.9900 - val_loss: 0.2448 - val_accuracy: 0.8825
Epoch 23/50
16/16 [==============================] - ETA: 0s - loss: 0.0632 - accuracy: 0.9920
Epoch 23: val_accuracy did not improve from 0.90637
16/16 [==============================] - 77s 5s/step - loss: 0.0632 - accuracy: 0.9920 - val_loss: 0.2235 - val_accuracy: 0.8964
Epoch 24/50
16/16 [==============================] - ETA: 0s - loss: 0.0466 - accuracy: 1.0000
Epoch 24: val_accuracy did not improve from 0.90637
16/16 [==============================] - 77s 5s/step - loss: 0.0466 - accuracy: 1.0000 - val_loss: 0.2267 - val_accuracy: 0.8964
Epoch 25/50
16/16 [==============================] - ETA: 0s - loss: 0.0672 - accuracy: 0.9841
Epoch 25: val_accuracy did not improve from 0.90637
16/16 [==============================] - 76s 5s/step - loss: 0.0672 - accuracy: 0.9841 - val_loss: 0.2306 - val_accuracy: 0.8944
Epoch 26/50
16/16 [==============================] - ETA: 0s - loss: 0.0594 - accuracy: 0.9900
Epoch 26: val_accuracy did not improve from 0.90637
16/16 [==============================] - 77s 5s/step - loss: 0.0594 - accuracy: 0.9900 - val_loss: 0.2322 - val_accuracy: 0.8865
Epoch 27/50
16/16 [==============================] - ETA: 0s - loss: 0.0478 - accuracy: 0.9960
Epoch 27: val_accuracy did not improve from 0.90637
16/16 [==============================] - 77s 5s/step - loss: 0.0478 - accuracy: 0.9960 - val_loss: 0.2220 - val_accuracy: 0.9044
Epoch 28/50
16/16 [==============================] - ETA: 0s - loss: 0.0479 - accuracy: 0.9920
Epoch 28: val_accuracy did not improve from 0.90637
16/16 [==============================] - 76s 5s/step - loss: 0.0479 - accuracy: 0.9920 - val_loss: 0.2369 - val_accuracy: 0.8944
Epoch 29/50
16/16 [==============================] - ETA: 0s - loss: 0.0420 - accuracy: 0.9920
Epoch 29: val_accuracy did not improve from 0.90637
16/16 [==============================] - 75s 5s/step - loss: 0.0420 - accuracy: 0.9920 - val_loss: 0.2372 - val_accuracy: 0.8865
Epoch 29: early stopping

Process finished with exit code 0

2. Study Notes

2.1 Data Augmentation

data_augmentation = tf.keras.Sequential([
    layers.experimental.preprocessing.RandomFlip('horizontal_and_vertical'),
    layers.experimental.preprocessing.RandomRotation(0.2),
    layers.experimental.preprocessing.RandomZoom(0.2),
])

tf.keras.Sequential builds a sequential model, used here for data augmentation. Sequential is a Keras model type that represents a linear stack of layers.

RandomFlip('horizontal_and_vertical'): randomly flips images horizontally and vertically.

RandomRotation(0.2): randomly rotates images. The factor 0.2 means rotations within ±20% of a full circle, i.e. roughly ±72°.

RandomZoom(0.2): randomly zooms images in or out by up to 20%, i.e. roughly an 80%–120% scale range.

The goal of data augmentation is to increase the diversity of the training set by applying random transforms, such as rotations, flips, and zooms, to the training images. This helps prevent overfitting and improves the model's ability to generalize.
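As a minimal sketch of what an augmentation layer does (using NumPy instead of the Keras preprocessing layers, so it runs without TensorFlow), a random horizontal/vertical flip can be written as:

```python
import numpy as np

def random_flip(image, rng):
    """Randomly flip an H×W×C image horizontally and/or vertically,
    mimicking RandomFlip('horizontal_and_vertical')."""
    if rng.random() < 0.5:
        image = image[:, ::-1, :]  # horizontal flip (reverse the columns)
    if rng.random() < 0.5:
        image = image[::-1, :, :]  # vertical flip (reverse the rows)
    return image

rng = np.random.default_rng(123)
img = np.arange(12, dtype=np.float32).reshape(2, 2, 3)
augmented = random_flip(img, rng)
# Augmentation only reorders pixels; the multiset of pixel values is unchanged.
assert augmented.shape == img.shape
```

Because each call samples the flips at random, the same input yields different augmented views across epochs, which is exactly where the extra diversity comes from.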

2.2 Explicitly Building the Model

model.build((None, img_height, img_width, 3))

model.build(input_shape): explicitly builds the model, where input_shape defines the shape of the model's input. Here, (None, img_height, img_width, 3) means any number of RGB images, each of size img_height × img_width.

Explicitly building the model ensures that it is fully constructed before Keras features such as printing the model summary or saving the model architecture are used. If the model has not been built, some of these operations will raise errors.
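The reason build() exists can be illustrated without TensorFlow: Keras layers create their weights lazily, only once the input shape is known. The class below is a hypothetical miniature version of that mechanism, not actual Keras code:

```python
class LazyDense:
    """A toy dense layer that, like Keras layers, defers weight-shape
    creation until the input shape is known (hypothetical illustration)."""

    def __init__(self, units):
        self.units = units
        self.weights_shape = None  # unknown until build() is called
        self.built = False

    def build(self, input_shape):
        # input_shape = (batch, features); batch may be None (any size)
        self.weights_shape = (input_shape[-1], self.units)
        self.built = True

    def summary(self):
        if not self.built:
            raise RuntimeError("Layer not built yet: call build() first, "
                               "or pass data through the layer once.")
        return f"Dense({self.weights_shape[0]} -> {self.weights_shape[1]})"

layer = LazyDense(512)
layer.build((None, 25088))  # explicit build, like model.build(...)
print(layer.summary())      # Dense(25088 -> 512)
```

Calling summary() before build() raises, which mirrors why model.summary() fails on an unbuilt Keras model.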

2.3 Callbacks

# Save the best model weights
checkpointer = ModelCheckpoint('best_model.h5',
                               monitor='val_accuracy',
                               verbose=1,
                               save_best_only=True,
                               save_weights_only=True)
# Checkpoint callback: save the model weights whenever validation accuracy reaches a new best.

# Set up early stopping
earlystopper = EarlyStopping(monitor='val_accuracy',
                             min_delta=0.001,
                             patience=30,  # generous patience value
                             verbose=1,
                             restore_best_weights=True)
# Early-stopping callback: stop training when validation accuracy stops improving and restore the best weights.

ModelCheckpoint: saves the model during training. Parameters:

  • 'best_model.h5': the filename to save the model to.
  • monitor='val_accuracy': monitor validation accuracy.
  • verbose=1: log progress during training.
  • save_best_only=True: only save when validation accuracy reaches a new best.
  • save_weights_only=True: save only the model weights, not the full model architecture.

EarlyStopping: stops training early when validation accuracy stops improving. Parameters:

  • monitor='val_accuracy': monitor validation accuracy.
  • min_delta=0.001: an increase in validation accuracy smaller than this threshold does not count as an improvement.
  • patience=30: stop training after validation accuracy fails to improve for 30 consecutive epochs.
  • verbose=1: log progress during training.
  • restore_best_weights=True: after stopping, restore the model weights from the epoch with the best validation accuracy.

Callbacks are functions executed automatically during training to adjust the training process dynamically, for example to save the model or stop early. Two commonly used callbacks are ModelCheckpoint and EarlyStopping.
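The logic behind EarlyStopping with restore_best_weights can be sketched in plain Python. This hypothetical helper tracks the best validation accuracy per epoch and stops after `patience` epochs without an improvement greater than `min_delta`:

```python
def early_stopping_run(val_accuracies, min_delta=0.001, patience=3):
    """Plain-Python sketch of EarlyStopping(monitor='val_accuracy',
    restore_best_weights=True): returns (stop_epoch, best_epoch)."""
    best_acc = float("-inf")
    best_epoch = None
    wait = 0  # epochs since the last improvement
    for epoch, acc in enumerate(val_accuracies, start=1):
        if acc - best_acc > min_delta:  # improvement must exceed min_delta
            best_acc, best_epoch = acc, epoch
            wait = 0                    # reset the patience counter
        else:
            wait += 1
            if wait >= patience:
                # training stops here; with restore_best_weights=True the
                # weights saved at best_epoch would be restored
                return epoch, best_epoch
    return len(val_accuracies), best_epoch

# Accuracy peaks at epoch 3, then stalls for patience=3 epochs:
stop_epoch, best_epoch = early_stopping_run(
    [0.50, 0.65, 0.81, 0.80, 0.81, 0.81])
# → stops at epoch 6, restoring the weights from epoch 3
```

This is why the training log above reports "Epoch 29: early stopping" once val_accuracy has gone `patience` epochs without beating its best value.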
