>- **🍨 This post is a learning-log entry for the [🔗 365-Day Deep Learning Training Camp](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg)**
>- **🍖 Original author: [K同学啊](https://mtyjkh.blog.csdn.net/)**
1. Getting the Model Running
1.1 Environment
IDE: PyCharm Community
Language: Python 3.9.19
Deep learning framework: TensorFlow 2.9.0
1.2 Workflow
- **Import the required libraries**: import the Python standard-library and third-party modules needed for data processing, model building, and training.
- **Set environment variables**: set an environment variable to avoid an OpenMP error that occurs with concurrent runtimes on some systems.
- **Configure the GPU**: check whether a GPU is available; if so, enable on-demand memory growth and set it as the visible device, otherwise print a notice that the CPU will be used.
- **Load the dataset**: set the local dataset directory and count all image files under it.
- **Display a sample image**: load and display one sample image to confirm the image paths are correct.
- **Set image-processing parameters**: set the batch size (batch_size) and the image dimensions (img_height, img_width).
- **Load the training and validation datasets**: load both datasets from the given directory, resizing images to the chosen dimensions and batching them.
- **Print the class names**: retrieve and print the names of all classes in the dataset.
- **Visualize training images**: plot one batch of training images with their class labels to confirm that loading and labeling are correct.
- **Print batch shapes**: print the shapes of one batch of images and labels to confirm the batch size and image dimensions.
- **Optimize data loading**: use caching, shuffling, and prefetching to speed up the input pipeline.
- **Define data augmentation**: random horizontal and vertical flips, random rotation, and random zoom, to increase the diversity of the training data.
- **Build the model**: a convolutional neural network with augmentation, normalization, convolution, pooling, fully connected, and Dropout layers.
- **Explicitly build the model**: build the model explicitly to avoid "model not built" errors, and print its summary to inspect the structure and parameter counts.
- **Set the learning-rate schedule and optimizer**: define a schedule with an initial learning rate that decays exponentially, and pass it to the optimizer.
- **Compile the model**: specify the optimizer, loss function, and evaluation metrics.
- **Define callbacks**: a checkpoint callback that saves the model weights whenever validation accuracy reaches a new best, and an early-stopping callback that halts training when validation accuracy stops improving and restores the best weights.
- **Train the model**: fit on the training and validation data for the set number of epochs, with the callbacks attached.
- **Visualize training and validation**: plot accuracy and loss curves to show how model performance changes over the epochs.
- **Summarize and analyze**: use the plots and metrics to assess the training process and the final model.
1.3 Complete Code
import os
import tensorflow as tf
import pathlib
import PIL.Image
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
# Import the required libraries: the OS interface (os), TensorFlow, path handling (pathlib),
# image handling (PIL), plotting (matplotlib), and the Keras modules used to build and train the model.
# Set an environment variable to avoid an OpenMP error
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# This works around an OpenMP duplicate-library error that can occur on some systems.
# GPU configuration
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    gpu0 = gpus[0]
    tf.config.experimental.set_memory_growth(gpu0, True)  # grow GPU memory on demand
    tf.config.set_visible_devices([gpu0], "GPU")
    print("GPU available.")
else:
    print("GPU cannot be found, using CPU instead.")
# Check whether a GPU is available; if so, enable memory growth and make it the visible device;
# otherwise fall back to the CPU.
data_dir = "D:/others/pycharm/pythonProject/T5"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*/*.jpg')))
print("Total number of images:", image_count)
# Set the dataset directory and count all jpg images under it.
roses = list(data_dir.glob('train/nike/*.jpg'))
PIL.Image.open(str(roses[0]))
# Open one sample image with PIL to confirm the paths are correct.
batch_size = 32
img_height = 224
img_width = 224
# Image-processing parameters: batch size and target image height/width.
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "D:/others/pycharm/pythonProject/T5/train/",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "D:/others/pycharm/pythonProject/T5/train/",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
# Load the training and validation datasets, resizing images to the chosen size and batching them.
# Note: both datasets are read from the same directory without a validation_split,
# so the validation set here is the same data as the training set.
class_names = train_ds.class_names
print(class_names)
# Print the class names found in the dataset.
plt.figure(figsize=(20, 10))
for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)
        try:
            plt.imshow(images[i].numpy().astype("uint8"))
            plt.title(class_names[labels[i]])
            plt.axis("off")
        except Exception as e:
            print(f"Error displaying image {i}: {e}")
plt.show()  # plt.show() is required to actually display the figure
# Visualize one batch of training images in a grid, labeled with their class names.
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
# Print the shapes of one batch of images and labels to confirm the batch size and image dimensions.
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
# Speed up the input pipeline with caching, shuffling, and prefetching.
# Data augmentation
data_augmentation = tf.keras.Sequential([
    layers.experimental.preprocessing.RandomFlip('horizontal_and_vertical'),
    layers.experimental.preprocessing.RandomRotation(0.2),
    layers.experimental.preprocessing.RandomZoom(0.2),
])
# Define the augmentation pipeline: random horizontal/vertical flips, random rotation, and random zoom.
model = models.Sequential([
    data_augmentation,  # augmentation layer
    layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=(img_height, img_width, 3)),  # normalize pixels to [0, 1]
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(512, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(len(class_names))
])
# Build the CNN: augmentation, normalization, convolution, pooling, fully connected, and Dropout layers.
# Explicitly build the model to avoid "model not built" errors
model.build((None, img_height, img_width, 3))
model.summary()  # print the network structure and parameter counts
# Initial learning rate
initial_learning_rate = 0.001
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=1000,
    decay_rate=0.9,
    staircase=True
)
# Learning-rate schedule: start at the initial rate and decay it exponentially over training.
# Pass the exponential-decay schedule to the optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
# Use the Adam optimizer with the schedule as its learning rate.
model.compile(optimizer=optimizer,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
# Compile the model, specifying the optimizer, loss function, and evaluation metrics.
epochs = 50
# Save the best model weights
checkpointer = ModelCheckpoint('best_model.h5',
                               monitor='val_accuracy',
                               verbose=1,
                               save_best_only=True,
                               save_weights_only=True)
# Checkpoint callback: save the weights whenever validation accuracy reaches a new best.
# Early stopping
earlystopper = EarlyStopping(monitor='val_accuracy',
                             min_delta=0.001,
                             patience=30,  # generous patience value
                             verbose=1,
                             restore_best_weights=True)
# Early-stopping callback: stop training when validation accuracy stops improving, restoring the best weights.
history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=epochs,
                    callbacks=[checkpointer, earlystopper])
# Train the model on the training and validation sets for the set number of epochs, with both callbacks.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(len(loss))
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
# Plot training/validation accuracy and loss curves to show how model performance evolves over the epochs.
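As an aside, the exponential-decay schedule configured above can be checked by hand without TensorFlow. This is a minimal sketch (not the Keras implementation): with `staircase=True`, the rate follows `lr(step) = initial_learning_rate * decay_rate ** (step // decay_steps)`.

```python
def staircase_exponential_decay(step, initial_lr=0.001, decay_steps=1000, decay_rate=0.9):
    """Mimics tf.keras.optimizers.schedules.ExponentialDecay with staircase=True."""
    # Integer division makes the decay happen in discrete jumps every decay_steps steps.
    return initial_lr * decay_rate ** (step // decay_steps)

for step in (0, 500, 1000, 2500):
    # The rate stays at 0.001 until step 1000, then shrinks by a factor of 0.9
    # for each additional 1000 steps.
    print(step, staircase_exponential_decay(step))
```

Note that at 16 batches per epoch (as in the run below), 29 epochs amount to only about 464 optimizer steps, so with `decay_steps=1000` this particular run never reaches the first decay step and effectively trains at the initial rate throughout.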
Full output:
D:\others\anaconda\envs\deep_learning_env\python.exe D:\others\pycharm\pythonProject\T5_Brand_of_sport_recognition.py
GPU cannot be found, using CPU instead.
Total number of images: 502
2024-07-07 16:17:57.783825: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-07-07 16:17:57.784783: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
Found 502 files belonging to 2 classes.
Found 502 files belonging to 2 classes.
['adidas', 'nike']
(32, 224, 224, 3)
(32,)
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) (None, 224, 224, 3) 0
rescaling (Rescaling) (None, 224, 224, 3) 0
vgg16 (Functional) (None, 7, 7, 512) 14714688
flatten (Flatten) (None, 25088) 0
dense (Dense) (None, 128) 3211392
dense_1 (Dense) (None, 2) 258
=================================================================
Total params: 17,926,338
Trainable params: 3,211,650
Non-trainable params: 14,714,688
_________________________________________________________________
Epoch 1/50
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
16/16 [==============================] - ETA: 0s - loss: 3.1690 - accuracy: 0.5359
Epoch 1: val_accuracy improved from -inf to 0.50199, saving model to best_model.h5
16/16 [==============================] - 76s 5s/step - loss: 3.1690 - accuracy: 0.5359 - val_loss: 1.6623 - val_accuracy: 0.5020
Epoch 2/50
16/16 [==============================] - ETA: 0s - loss: 0.8153 - accuracy: 0.6554
Epoch 2: val_accuracy improved from 0.50199 to 0.64542, saving model to best_model.h5
16/16 [==============================] - 74s 5s/step - loss: 0.8153 - accuracy: 0.6554 - val_loss: 0.7469 - val_accuracy: 0.6454
Epoch 3/50
16/16 [==============================] - ETA: 0s - loss: 0.4668 - accuracy: 0.7928
Epoch 3: val_accuracy improved from 0.64542 to 0.69721, saving model to best_model.h5
16/16 [==============================] - 73s 5s/step - loss: 0.4668 - accuracy: 0.7928 - val_loss: 0.6094 - val_accuracy: 0.6972
Epoch 4/50
16/16 [==============================] - ETA: 0s - loss: 0.3502 - accuracy: 0.8625
Epoch 4: val_accuracy improved from 0.69721 to 0.81275, saving model to best_model.h5
16/16 [==============================] - 74s 5s/step - loss: 0.3502 - accuracy: 0.8625 - val_loss: 0.4189 - val_accuracy: 0.8127
Epoch 5/50
16/16 [==============================] - ETA: 0s - loss: 0.3005 - accuracy: 0.8785
Epoch 5: val_accuracy improved from 0.81275 to 0.85060, saving model to best_model.h5
16/16 [==============================] - 74s 5s/step - loss: 0.3005 - accuracy: 0.8785 - val_loss: 0.3549 - val_accuracy: 0.8506
Epoch 6/50
16/16 [==============================] - ETA: 0s - loss: 0.2425 - accuracy: 0.9143
Epoch 6: val_accuracy did not improve from 0.85060
16/16 [==============================] - 74s 5s/step - loss: 0.2425 - accuracy: 0.9143 - val_loss: 0.3486 - val_accuracy: 0.8327
Epoch 7/50
16/16 [==============================] - ETA: 0s - loss: 0.2191 - accuracy: 0.9263
Epoch 7: val_accuracy improved from 0.85060 to 0.86653, saving model to best_model.h5
16/16 [==============================] - 73s 5s/step - loss: 0.2191 - accuracy: 0.9263 - val_loss: 0.3106 - val_accuracy: 0.8665
Epoch 8/50
16/16 [==============================] - ETA: 0s - loss: 0.2055 - accuracy: 0.9382
Epoch 8: val_accuracy did not improve from 0.86653
16/16 [==============================] - 74s 5s/step - loss: 0.2055 - accuracy: 0.9382 - val_loss: 0.3532 - val_accuracy: 0.8267
Epoch 9/50
16/16 [==============================] - ETA: 0s - loss: 0.2486 - accuracy: 0.9044
Epoch 9: val_accuracy improved from 0.86653 to 0.87251, saving model to best_model.h5
16/16 [==============================] - 75s 5s/step - loss: 0.2486 - accuracy: 0.9044 - val_loss: 0.3037 - val_accuracy: 0.8725
Epoch 10/50
16/16 [==============================] - ETA: 0s - loss: 0.1857 - accuracy: 0.9402
Epoch 10: val_accuracy improved from 0.87251 to 0.87649, saving model to best_model.h5
16/16 [==============================] - 79s 5s/step - loss: 0.1857 - accuracy: 0.9402 - val_loss: 0.2883 - val_accuracy: 0.8765
Epoch 11/50
16/16 [==============================] - ETA: 0s - loss: 0.1697 - accuracy: 0.9422
Epoch 11: val_accuracy did not improve from 0.87649
16/16 [==============================] - 74s 5s/step - loss: 0.1697 - accuracy: 0.9422 - val_loss: 0.3102 - val_accuracy: 0.8566
Epoch 12/50
16/16 [==============================] - ETA: 0s - loss: 0.1545 - accuracy: 0.9522
Epoch 12: val_accuracy did not improve from 0.87649
16/16 [==============================] - 76s 5s/step - loss: 0.1545 - accuracy: 0.9522 - val_loss: 0.3104 - val_accuracy: 0.8685
Epoch 13/50
16/16 [==============================] - ETA: 0s - loss: 0.1471 - accuracy: 0.9641
Epoch 13: val_accuracy did not improve from 0.87649
16/16 [==============================] - 78s 5s/step - loss: 0.1471 - accuracy: 0.9641 - val_loss: 0.2829 - val_accuracy: 0.8705
Epoch 14/50
16/16 [==============================] - ETA: 0s - loss: 0.1438 - accuracy: 0.9602
Epoch 14: val_accuracy improved from 0.87649 to 0.89442, saving model to best_model.h5
16/16 [==============================] - 74s 5s/step - loss: 0.1438 - accuracy: 0.9602 - val_loss: 0.2565 - val_accuracy: 0.8944
Epoch 15/50
16/16 [==============================] - ETA: 0s - loss: 0.1189 - accuracy: 0.9721
Epoch 15: val_accuracy did not improve from 0.89442
16/16 [==============================] - 75s 5s/step - loss: 0.1189 - accuracy: 0.9721 - val_loss: 0.2499 - val_accuracy: 0.8865
Epoch 16/50
16/16 [==============================] - ETA: 0s - loss: 0.1098 - accuracy: 0.9741
Epoch 16: val_accuracy did not improve from 0.89442
16/16 [==============================] - 81s 5s/step - loss: 0.1098 - accuracy: 0.9741 - val_loss: 0.2466 - val_accuracy: 0.8884
Epoch 17/50
16/16 [==============================] - ETA: 0s - loss: 0.1156 - accuracy: 0.9781
Epoch 17: val_accuracy did not improve from 0.89442
16/16 [==============================] - 78s 5s/step - loss: 0.1156 - accuracy: 0.9781 - val_loss: 0.2794 - val_accuracy: 0.8745
Epoch 18/50
16/16 [==============================] - ETA: 0s - loss: 0.0900 - accuracy: 0.9861
Epoch 18: val_accuracy did not improve from 0.89442
16/16 [==============================] - 75s 5s/step - loss: 0.0900 - accuracy: 0.9861 - val_loss: 0.2404 - val_accuracy: 0.8845
Epoch 19/50
16/16 [==============================] - ETA: 0s - loss: 0.0839 - accuracy: 0.9781
Epoch 19: val_accuracy improved from 0.89442 to 0.90637, saving model to best_model.h5
16/16 [==============================] - 77s 5s/step - loss: 0.0839 - accuracy: 0.9781 - val_loss: 0.2263 - val_accuracy: 0.9064
Epoch 20/50
16/16 [==============================] - ETA: 0s - loss: 0.0802 - accuracy: 0.9861
Epoch 20: val_accuracy did not improve from 0.90637
16/16 [==============================] - 76s 5s/step - loss: 0.0802 - accuracy: 0.9861 - val_loss: 0.2437 - val_accuracy: 0.8964
Epoch 21/50
16/16 [==============================] - ETA: 0s - loss: 0.0674 - accuracy: 0.9880
Epoch 21: val_accuracy did not improve from 0.90637
16/16 [==============================] - 79s 5s/step - loss: 0.0674 - accuracy: 0.9880 - val_loss: 0.2183 - val_accuracy: 0.9064
Epoch 22/50
16/16 [==============================] - ETA: 0s - loss: 0.0661 - accuracy: 0.9900
Epoch 22: val_accuracy did not improve from 0.90637
16/16 [==============================] - 78s 5s/step - loss: 0.0661 - accuracy: 0.9900 - val_loss: 0.2448 - val_accuracy: 0.8825
Epoch 23/50
16/16 [==============================] - ETA: 0s - loss: 0.0632 - accuracy: 0.9920
Epoch 23: val_accuracy did not improve from 0.90637
16/16 [==============================] - 77s 5s/step - loss: 0.0632 - accuracy: 0.9920 - val_loss: 0.2235 - val_accuracy: 0.8964
Epoch 24/50
16/16 [==============================] - ETA: 0s - loss: 0.0466 - accuracy: 1.0000
Epoch 24: val_accuracy did not improve from 0.90637
16/16 [==============================] - 77s 5s/step - loss: 0.0466 - accuracy: 1.0000 - val_loss: 0.2267 - val_accuracy: 0.8964
Epoch 25/50
16/16 [==============================] - ETA: 0s - loss: 0.0672 - accuracy: 0.9841
Epoch 25: val_accuracy did not improve from 0.90637
16/16 [==============================] - 76s 5s/step - loss: 0.0672 - accuracy: 0.9841 - val_loss: 0.2306 - val_accuracy: 0.8944
Epoch 26/50
16/16 [==============================] - ETA: 0s - loss: 0.0594 - accuracy: 0.9900
Epoch 26: val_accuracy did not improve from 0.90637
16/16 [==============================] - 77s 5s/step - loss: 0.0594 - accuracy: 0.9900 - val_loss: 0.2322 - val_accuracy: 0.8865
Epoch 27/50
16/16 [==============================] - ETA: 0s - loss: 0.0478 - accuracy: 0.9960
Epoch 27: val_accuracy did not improve from 0.90637
16/16 [==============================] - 77s 5s/step - loss: 0.0478 - accuracy: 0.9960 - val_loss: 0.2220 - val_accuracy: 0.9044
Epoch 28/50
16/16 [==============================] - ETA: 0s - loss: 0.0479 - accuracy: 0.9920
Epoch 28: val_accuracy did not improve from 0.90637
16/16 [==============================] - 76s 5s/step - loss: 0.0479 - accuracy: 0.9920 - val_loss: 0.2369 - val_accuracy: 0.8944
Epoch 29/50
16/16 [==============================] - ETA: 0s - loss: 0.0420 - accuracy: 0.9920
Epoch 29: val_accuracy did not improve from 0.90637
16/16 [==============================] - 75s 5s/step - loss: 0.0420 - accuracy: 0.9920 - val_loss: 0.2372 - val_accuracy: 0.8865
Epoch 29: early stopping
Process finished with exit code 0
2. Learning Notes
2.1 Data Augmentation
data_augmentation = tf.keras.Sequential([
    layers.experimental.preprocessing.RandomFlip('horizontal_and_vertical'),
    layers.experimental.preprocessing.RandomRotation(0.2),
    layers.experimental.preprocessing.RandomZoom(0.2),
])
- `tf.keras.Sequential`: builds a sequential model, used here to stack the augmentation layers. `Sequential` is the Keras model type that represents a linear stack of layers.
- `RandomFlip('horizontal_and_vertical')`: randomly flips images horizontally and vertically.
- `RandomRotation(0.2)`: randomly rotates images, with a rotation range of ±20% of a full circle.
- `RandomZoom(0.2)`: randomly zooms images, within [80%, 120%] of the original size.

The purpose of data augmentation is to increase the diversity of the dataset by applying random transformations such as rotation, flipping, and zooming to the training images; this helps prevent overfitting and improves the model's ability to generalize.
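To make the flip transforms concrete, here is a minimal NumPy sketch of what `RandomFlip('horizontal_and_vertical')` does to a single image array. This is illustrative only: the Keras layer additionally handles batches and is active only during training.

```python
import numpy as np

def horizontal_flip(image):
    return image[:, ::-1]      # reverse the width (column) axis

def vertical_flip(image):
    return image[::-1]         # reverse the height (row) axis

def random_flip(image, rng):
    """Apply each flip with probability 0.5, like RandomFlip('horizontal_and_vertical')."""
    if rng.random() < 0.5:
        image = horizontal_flip(image)
    if rng.random() < 0.5:
        image = vertical_flip(image)
    return image

img = np.array([[1, 2],
                [3, 4]])
print(horizontal_flip(img))    # [[2 1] [4 3]]
print(vertical_flip(img))      # [[3 4] [1 2]]
print(random_flip(img, np.random.default_rng(0)).shape)  # the shape is always preserved
```

Because only the pixel arrangement changes, the label of an augmented image stays valid, which is what makes these transforms safe to apply to the training set.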
2.2 Explicitly Building the Model
model.build((None, img_height, img_width, 3))
- `model.build(input_shape)`: builds the model explicitly, where `input_shape` defines the shape of the model's input. Here `(None, img_height, img_width, 3)` means any number of RGB images, each of size `img_height x img_width`.

Building the model explicitly ensures it is fully constructed before Keras features such as printing the model summary or saving the model structure are used; without it, some of these operations can raise a "model not built" error.
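What `build` infers for the layer output shapes can be reproduced by hand. The sketch below traces the input side length through the Conv2D/MaxPooling2D stack defined in 1.3, assuming the Keras defaults used there: a 3×3 convolution with 'valid' padding shrinks each side by 2, and 2×2 max pooling halves it, rounding down.

```python
def conv3x3(side):
    return side - 2            # 3x3 kernel, 'valid' padding, stride 1

def maxpool2(side):
    return side // 2           # 2x2 pooling with stride 2, rounding down

side = 224                     # img_height = img_width = 224
for _ in range(3):             # three Conv2D + MaxPooling2D pairs
    side = maxpool2(conv3x3(side))

print(side)                    # 26: 224 -> 222 -> 111 -> 109 -> 54 -> 52 -> 26
print(side * side * 128)       # 86528 units after Flatten (the last conv has 128 filters)
```

This is exactly the shape bookkeeping that fails with a "model not built" error when no input shape has been provided.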
2.3 Callbacks
# Save the best model weights
checkpointer = ModelCheckpoint('best_model.h5',
                               monitor='val_accuracy',
                               verbose=1,
                               save_best_only=True,
                               save_weights_only=True)
# Checkpoint callback: save the weights whenever validation accuracy reaches a new best.
# Early stopping
earlystopper = EarlyStopping(monitor='val_accuracy',
                             min_delta=0.001,
                             patience=30,  # generous patience value
                             verbose=1,
                             restore_best_weights=True)
# Early-stopping callback: stop training when validation accuracy stops improving, restoring the best weights.
- `ModelCheckpoint`: saves the model during training. Parameters:
  - `'best_model.h5'`: file name for the saved model.
  - `monitor='val_accuracy'`: monitor validation accuracy.
  - `verbose=1`: log progress during training.
  - `save_best_only=True`: only save the model with the highest validation accuracy so far.
  - `save_weights_only=True`: save only the weights, not the full model structure.
- `EarlyStopping`: stops training early when validation accuracy no longer improves. Parameters:
  - `monitor='val_accuracy'`: monitor validation accuracy.
  - `min_delta=0.001`: an increase in validation accuracy smaller than this threshold does not count as an improvement.
  - `patience=30`: stop training after 30 consecutive epochs without improvement.
  - `verbose=1`: log progress during training.
  - `restore_best_weights=True`: after stopping, restore the weights from the epoch with the best validation accuracy.

Callbacks are functions that run automatically during training to adjust the process dynamically, save the model, stop early, and so on. `ModelCheckpoint` and `EarlyStopping` are two of the most commonly used.
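Putting the `EarlyStopping` parameters together, the core logic can be sketched in plain Python. This is a simplified, hypothetical reimplementation for intuition only; the real Keras callback hooks into the training loop and manages the model weights itself.

```python
def early_stopping(val_accuracies, min_delta=0.001, patience=3):
    """Return (stop_epoch, best_epoch), 0-indexed, for a list of per-epoch metric values."""
    best = float('-inf')
    best_epoch = 0
    wait = 0                              # epochs since the last real improvement
    for epoch, acc in enumerate(val_accuracies):
        if acc - best > min_delta:        # an improvement must exceed min_delta to count
            best, best_epoch, wait = acc, epoch, 0
        else:
            wait += 1
            if wait >= patience:          # patience exhausted: stop training here
                return epoch, best_epoch  # restore_best_weights would reload best_epoch
    return len(val_accuracies) - 1, best_epoch

# Accuracy peaks at epoch 2, then stalls for `patience` epochs, so training
# stops at epoch 5 and the weights from epoch 2 would be restored:
print(early_stopping([0.50, 0.65, 0.70, 0.699, 0.700, 0.695]))  # (5, 2)
```

The `min_delta` check is what makes the 0.699 → 0.700 wobble above not count as progress: only increases larger than the threshold reset the patience counter.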