深度学习Week12——利用TensorFlow实现好莱坞明星识别

文章目录
深度学习Week12——利用TensorFlow实现好莱坞明星识别
一、前言
二、我的环境
三、前期工作
1、配置环境
2、导入数据
四、数据预处理
1、加载数据
2、可视化数据
3、检查数据
4、配置数据集
五、构建CNN模型
1、设置动态学习率
2、早停与保存最佳模型参数
五、编译模型
六、训练模型
七、预测与评估
1、Accuracy图
2、指定图像预测
八、拓展

一、前言

本篇内容分为两个部分,前面部分是学习K同学给的算法知识点以及复现,后半部分是自己的拓展与未解决的问题

二、我的环境

  • 电脑系统:Windows 10
  • 语言环境:Python 3.8.0
  • 编译器:Pycharm2023.2.3
    深度学习环境:TensorFlow
    显卡及显存:RTX 3060 8G

三、前期工作

1、导入库并配置环境

from tensorflow       import keras
from tensorflow.keras import layers,models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow        as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]                                        #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")
    
gpus

这一步与pytorch第一步类似,我们在写神经网络程序前无论是选择pytorch还是tensorflow都应该配置好gpu环境(如果有gpu的话)

2、 导入数据

导入所有好莱坞明星照片数据,依次分别为训练集图片(train_images)、训练集标签(train_labels)、测试集图片(test_images)、测试集标签(test_labels),数据集来源于K同学啊的网盘:数据集

data_dir = "E:\Deep_Learning\Data\Week6"

data_dir = pathlib.Path(data_dir)

image_count = len(list(data_dir.glob('*/*.jpg')))

print("图片总数为:",image_count)

#查看第一张图片:
roses = list(data_dir.glob('Jennifer Lawrence/*.jpg'))
PIL.Image.open(str(roses[8]))

图片总数为: 1800
在这里插入图片描述

四、数据预处理

1、加载数据

batch_size = 32
img_height = 224
img_width = 224

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

tf.keras.preprocessing.image_dataset_from_directory()会将文件夹中的数据加载到tf.data.Dataset中,且加载的同时会打乱数据。

  • class_names
  • validation_split: 0和1之间的可选浮点数,可保留一部分数据用于验证。
  • subset: training或validation之一。仅在设置validation_split时使用。
  • seed: 用于shuffle和转换的可选随机种子。
  • batch_size: 数据批次的大小。默认值:32
  • image_size: 从磁盘读取数据后将其重新调整大小。默认:(256,256)。由于管道处理的图像批次必须具有相同的大小,因此该参数必须提供。
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.1,
    subset="training",
    label_mode = "categorical",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

输出:

Found 1800 files belonging to 17 classes.
Using 1620 files for training.
  1. 验证集并没有参与训练过程梯度下降过程的,狭义上来讲是没有参与模型的参数训练更新的。
  2. 但是广义上来讲,验证集存在的意义确实参与了一个“人工调参”的过程,我们根据每一个epoch训练之后模型在valid data上的表现来决定是否需要训练进行early stop,或者根据这个过程模型的性能变化来调整模型的超参数,如学习率,batch_size等等。
  3. 因此,我们也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.1,
    subset="validation",
    label_mode = "categorical",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

输出:

Found 1800 files belonging to 1 classes.
Using 180 files for validation.

我们可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。

class_names = train_ds.class_names
print(class_names)

[‘Angelina Jolie’, ‘Brad Pitt’, ‘Denzel Washington’, ‘Hugh Jackman’, ‘Jennifer Lawrence’, ‘Johnny Depp’, ‘Kate Winslet’, ‘Leonardo DiCaprio’, ‘Megan Fox’, ‘Natalie Portman’, ‘Nicole Kidman’, ‘Robert Downey Jr’, ‘Sandra Bullock’, ‘Scarlett Johansson’, ‘Tom Cruise’, ‘Tom Hanks’, ‘Will Smith’]

2、数据可视化

# 查看前20个图片
plt.figure(figsize=(20, 10))

for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

在这里插入图片描述

3、再次检查数据

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

(32, 224, 224, 3)
(32,)
Image_batch是形状的张量(32,224,224,3)。这是一批形状224x224x3的32张图片(最后一维指的是彩色通道RGB。
Label_batch是形状(32,)的张量,这些标签对应32张图片

4、配置数据集

  • shuffle():打乱数据
  • prefetch():预取数据,加速运行
  • cache():将数据集缓存到内存当中,加速运行

如果不使用prefetch(),CPU 和 GPU/TPU 在大部分时间都处于空闲状态:
使用前
使用prefetch()可显著减少空闲时间:
在这里插入图片描述

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

五 、构建CNN模型

卷积神经网络(CNN)的输入是张量 (Tensor) 形式的(image_height, image_width, color_channels),包含了图像高度、宽度及颜色信息。不需要输入batch size。color_channels 为 (R,G,B) 分别对应 RGB 的三个颜色通道(color channel)。在此示例中,我们的 CNN 输入形状是(224, 224, 3)。我们需要在声明第一层时将形状赋值给参数input_shape

这是一个重难点,在构建模型之前,我们先来看一看各层有什么作用以及网络结构图
在这里插入图片描述

  1. 输入层:
    输入层负责接收原始数据,将数据传递到网络中的第一层。
  2. 卷积层:
    卷积层使用卷积核对输入数据进行滤波操作,以提取图像中的特征。
  3. 池化层:
    池化层用于对卷积层的输出进行下采样,以减少数据的维度和计算量。
  4. Flatten层:
    Flatten层用于将多维的输入数据(如卷积层的输出)压缩成一维的向量。
    常用在卷积层到全连接层的过渡,将卷积层输出的特征图展平成一维向量,以便输入到全连接层中进行分类或回归等任务。
  5. 全连接层:
    全连接层起到“特征提取器”的作用,将前面层的特征表示映射到输出层。
  6. 输出层:
    输出层负责输出模型的预测结果。

ReLu函数作为激活励函数可以增强判定函数和整个神经网络的非线性特性,而本身并不会改变卷积层;
相比其它函数来说,ReLU函数更受青睐,这是因为它可以将神经网络的训练速度提升数倍,而并不会对模型的泛化准确度造成显著影响。

num_classes = 2

"""
关于卷积核的计算不懂的可以参考文章:https://blog.csdn.net/qq_38251616/article/details/114278995

layers.Dropout(0.4) 作用是防止过拟合,提高模型的泛化能力。

关于Dropout层的更多介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/115826689
"""

model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
    
    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), # 卷积层1,卷积核3*3  
    layers.AveragePooling2D((2, 2)),               # 池化层1,2*2采样
    layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3
    layers.AveragePooling2D((2, 2)),               # 池化层2,2*2采样
    layers.Dropout(0.3),  
    layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3
    layers.Dropout(0.4),  
    
    layers.Flatten(),                       # Flatten层,连接卷积层与全连接层
    layers.Dense(128, activation='relu'),   # 全连接层,特征进一步提取
    layers.Dense(num_classes)               # 输出层,输出预期结果
])

model.summary()  # 打印网络结构
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 rescaling (Rescaling)       (None, 224, 224, 3)       0         
                                                                 
 conv2d (Conv2D)             (None, 222, 222, 16)      448       
                                                                 
 average_pooling2d (Average  (None, 111, 111, 16)      0         
 Pooling2D)                                                      
                                                                 
 conv2d_1 (Conv2D)           (None, 109, 109, 32)      4640      
                                                                 
 average_pooling2d_1 (Avera  (None, 54, 54, 32)        0         
 gePooling2D)                                                    
                                                                 
 dropout (Dropout)           (None, 54, 54, 32)        0         
                                                                 
 conv2d_2 (Conv2D)           (None, 52, 52, 64)        18496     
                                                                 
 average_pooling2d_2 (Avera  (None, 26, 26, 64)        0         
 gePooling2D)                                                    
                                                                 
 dropout_1 (Dropout)         (None, 26, 26, 64)        0         
                                                                 
 conv2d_3 (Conv2D)           (None, 24, 24, 128)       73856     
                                                                 
 dropout_2 (Dropout)         (None, 24, 24, 128)       0         
                                                                 
 flatten (Flatten)           (None, 73728)             0         
                                                                 
 dense (Dense)               (None, 128)               9437312   
                                                                 
 dense_1 (Dense)             (None, 17)                2193      
                                                                 
=================================================================
Total params: 9536945 (36.38 MB)
Trainable params: 9536945 (36.38 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

五、编译模型

具体函数解释参考第八周博客或者K同学啊的博客!

1.设置动态学习率

# 设置初始学习率
initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, 
        decay_steps=60,      # 敲黑板!!!这里是指 steps,不是指epochs
        decay_rate=0.96,     # lr经过一次衰减就会变成 decay_rate*lr
        staircase=True)

# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

model.compile(optimizer=optimizer,
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.compile(optimizer="adam",
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=['accuracy'])

2. 早停与保存最佳模型参数

EarlyStopping()参数说明:
monitor: 被监测的数据。
min_delta: 在被监测的数据中被认为是提升的最小变化, 例如,小于 min_delta 的绝对变化会被认为没有提升。
patience: 没有进步的训练轮数,在这之后训练就会被停止。
verbose: 详细信息模式。
mode: {auto, min, max} 其中之一。 在 min 模式中, 当被监测的数据停止下降,训练就会停止;在 max 模式中,当被监测的数据停止上升,训练就会停止;在 auto 模式中,方向会自动从被监测的数据的名字中判断出来。
baseline: 要监控的数量的基准值。 如果模型没有显示基准的改善,训练将停止。
estore_best_weights: 是否从具有监测数量的最佳值的时期恢复模型权重。 如果为 False,则使用在训练的最后一步获得的模型权重。

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

epochs = 100

# 保存最佳模型参数
checkpointer = ModelCheckpoint('best_model.h5',
                                monitor='val_accuracy',
                                verbose=1,
                                save_best_only=True,
                                save_weights_only=True)

# 设置早停
earlystopper = EarlyStopping(monitor='val_accuracy', 
                             min_delta=0.001,
                             patience=20, 
                             verbose=1)
                            

六、训练模型

history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=epochs,
                    callbacks=[checkpointer, earlystopper])
Epoch 1/100
51/51 [==============================] - ETA: 0s - loss: 8.0135 - accuracy: 0.0556
Epoch 1: val_accuracy improved from -inf to 0.04444, saving model to best_model.h5
51/51 [==============================] - 39s 719ms/step - loss: 8.0135 - accuracy: 0.0556 - val_loss: 8.8650 - val_accuracy: 0.0444
Epoch 2/100
51/51 [==============================] - ETA: 0s - loss: 8.2083 - accuracy: 0.0568
Epoch 2: val_accuracy did not improve from 0.04444
51/51 [==============================] - 37s 728ms/step - loss: 8.2083 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 3/100
51/51 [==============================] - ETA: 0s - loss: 8.1685 - accuracy: 0.0568
Epoch 3: val_accuracy did not improve from 0.04444
51/51 [==============================] - 36s 705ms/step - loss: 8.1685 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 4/100
51/51 [==============================] - ETA: 0s - loss: 8.1386 - accuracy: 0.0568
Epoch 4: val_accuracy did not improve from 0.04444
51/51 [==============================] - 32s 627ms/step - loss: 8.1386 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 5/100
51/51 [==============================] - ETA: 0s - loss: 8.1983 - accuracy: 0.0568
Epoch 5: val_accuracy did not improve from 0.04444
51/51 [==============================] - 36s 706ms/step - loss: 8.1983 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 6/100
51/51 [==============================] - ETA: 0s - loss: 8.1784 - accuracy: 0.0568
Epoch 6: val_accuracy did not improve from 0.04444
51/51 [==============================] - 34s 669ms/step - loss: 8.1784 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 7/100
51/51 [==============================] - ETA: 0s - loss: 8.1784 - accuracy: 0.0568
Epoch 7: val_accuracy did not improve from 0.04444
51/51 [==============================] - 32s 622ms/step - loss: 8.1784 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 8/100
51/51 [==============================] - ETA: 0s - loss: 8.1884 - accuracy: 0.0568
Epoch 8: val_accuracy did not improve from 0.04444
51/51 [==============================] - 31s 610ms/step - loss: 8.1884 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 9/100
51/51 [==============================] - ETA: 0s - loss: 8.1187 - accuracy: 0.0568
Epoch 9: val_accuracy did not improve from 0.04444
51/51 [==============================] - 31s 599ms/step - loss: 8.1187 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 10/100
51/51 [==============================] - ETA: 0s - loss: 8.1486 - accuracy: 0.0568
Epoch 10: val_accuracy did not improve from 0.04444
51/51 [==============================] - 30s 590ms/step - loss: 8.1486 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 11/100
51/51 [==============================] - ETA: 0s - loss: 8.1585 - accuracy: 0.0568
Epoch 11: val_accuracy did not improve from 0.04444
51/51 [==============================] - 29s 570ms/step - loss: 8.1585 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 12/100
51/51 [==============================] - ETA: 0s - loss: 8.1884 - accuracy: 0.0568
Epoch 12: val_accuracy did not improve from 0.04444
51/51 [==============================] - 28s 559ms/step - loss: 8.1884 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 13/100
51/51 [==============================] - ETA: 0s - loss: 8.1784 - accuracy: 0.0568
Epoch 13: val_accuracy did not improve from 0.04444
51/51 [==============================] - 28s 559ms/step - loss: 8.1784 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 14/100
51/51 [==============================] - ETA: 0s - loss: 8.2182 - accuracy: 0.0568
Epoch 14: val_accuracy did not improve from 0.04444
51/51 [==============================] - 28s 558ms/step - loss: 8.2182 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 15/100
51/51 [==============================] - ETA: 0s - loss: 8.2381 - accuracy: 0.0568
Epoch 15: val_accuracy did not improve from 0.04444
51/51 [==============================] - 29s 561ms/step - loss: 8.2381 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 16/100
51/51 [==============================] - ETA: 0s - loss: 8.1187 - accuracy: 0.0568
Epoch 16: val_accuracy did not improve from 0.04444
51/51 [==============================] - 29s 560ms/step - loss: 8.1187 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 17/100
51/51 [==============================] - ETA: 0s - loss: 8.1088 - accuracy: 0.0568
Epoch 17: val_accuracy did not improve from 0.04444
51/51 [==============================] - 29s 561ms/step - loss: 8.1088 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 18/100
51/51 [==============================] - ETA: 0s - loss: 8.0988 - accuracy: 0.0568
Epoch 18: val_accuracy did not improve from 0.04444
51/51 [==============================] - 29s 562ms/step - loss: 8.0988 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 19/100
51/51 [==============================] - ETA: 0s - loss: 8.1585 - accuracy: 0.0568
Epoch 19: val_accuracy did not improve from 0.04444
51/51 [==============================] - 29s 565ms/step - loss: 8.1585 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 20/100
51/51 [==============================] - ETA: 0s - loss: 8.0392 - accuracy: 0.0568
Epoch 20: val_accuracy did not improve from 0.04444
51/51 [==============================] - 29s 567ms/step - loss: 8.0392 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 21/100
51/51 [==============================] - ETA: 0s - loss: 8.1386 - accuracy: 0.0568
Epoch 21: val_accuracy did not improve from 0.04444
51/51 [==============================] - 29s 570ms/step - loss: 8.1386 - accuracy: 0.0568 - val_loss: 7.4322 - val_accuracy: 0.0444
Epoch 21: early stopping

七、预测

1、Accuracy图与Loss图

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(len(loss))

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

结果:
在这里插入图片描述

2、指定图片预测

# 加载效果最好的模型权重
model.load_weights('best_model.h5')

from PIL import Image
import numpy as np

img = Image.open(r"E:\Deep_Learning\Data\Week6\Tom Hanks\003_21d0aae6.jpg")  #这里选择你需要预测的图片
image = tf.image.resize(img, [img_height, img_width])

img_array = tf.expand_dims(image, 0) 

predictions = model.predict(img_array) # 这里选用你已经训练好的模型
print("预测结果为:",class_names[np.argmax(predictions)])

1/1 [==============================] - 0s 36ms/step
预测结果为: Tom Hanks

八、拓展

1. 修改初始动态学习率:

学习率设置为1e-5

在这里插入图片描述
🤔有一定提高

2.修改图片尺寸并修改参数

224224 → 256256,学习率仍为1e-5, patience为25。
在这里插入图片描述

3. 结论

这里我的acc最高达到了77.78%,除了降低初始学习率,还更改了图片尺寸,通过增加样本,让模型可以更好地泛化到新的数据;此外还修改了patience 值时,调整至25,此时模型会更容忍性能停滞的情况,训练会持续更长的时间。让其在一定程度上提高模型的性能,找到更好的参数配置。

  • 10
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
你好!对于心脏病预测的问题,使用循环神经网络(RNN)是一种常见的方法。RNN适用于处理序列数据,而心电图信号就是一种序列数据。在使用RNN进行心脏病预测时,你可以将心电图信号作为输入序列,然后通过训练RNN模型来预测患者是否患有心脏病。 首先,你需要准备一个合适的数据集,其中包含心电图信号和相应的心脏病标签。可以使用公开的心电图数据集,如PTB数据库或MIT-BIH数据库。然后,你可以对数据进行预处理和特征工程,如数据清洗、滤波、降采样等。 接下来,你可以构建一个RNN模型。RNN模型由一系列循环层组成,每个循环层都会处理一个时间步的输入数据。你可以选择不同类型的RNN单元,如简单循环单元(SimpleRNN)、长短期记忆网络(LSTM)或门控循环单元(GRU)。通过添加适当的全连接层和激活函数,你可以将RNN模型输出映射到二分类问题(有或无心脏病)的结果。 然后,你可以使用训练集对RNN模型进行训练,并使用验证集进行模型调优。在训练过程中,你可以使用适当的损失函数(如交叉熵)和优化算法(如随机梯度下降)来最小化模型的预测误差。 最后,你可以使用测试集对训练好的模型进行评估,并计算模型的性能指标,如准确率、精确率、召回率等。这些指标可以帮助你评估模型的预测能力和泛化能力。 需要注意的是,心脏病预测是一个复杂的医学问题,仅仅使用心电图信号可能不能得到准确的预测结果。通常情况下,还需要结合其他患者的临床信息和医学知识来进行综合评估。因此,在进行心脏病预测时,建议与专业医生合作,并遵循相关的医学准则和规范。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值