Week T6: Hollywood Celebrity Recognition with TensorFlow

Environment:
Language: Python 3.8.0
Editor: Jupyter Notebook
Deep learning framework: TensorFlow 2.15.0

I. Preparation

1. Configure the GPU (skip this step if you are running on a CPU)

import pathlib

import matplotlib.pyplot as plt
import numpy as np          # used later for np.argmax on one-hot labels
import PIL.Image            # import the submodule explicitly so PIL.Image.open works
import tensorflow as tf
from tensorflow.keras import layers, models

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]  # If there are several GPUs, use only GPU 0
    tf.config.experimental.set_memory_growth(gpu0, True)  # Allocate GPU memory on demand
    tf.config.set_visible_devices([gpu0], "GPU")

2. Import the data

data_dir = "./48-data/"
data_dir = pathlib.Path(data_dir)

3. Inspect the data

# Images live directly in one subdirectory per class: data_dir/<class_name>/*.jpg
image_count = len(list(data_dir.glob('*/*.jpg')))
print("Total number of images:", image_count)

Output: Total number of images: 1800
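Beyond the total, it helps to check how the images are spread across the classes, because a skewed split changes how accuracy should be read. Below is a small pathlib sketch. It assumes the root/<class_name>/*.jpg layout that image_dataset_from_directory expects. The helper name count_images_per_class is ours and not part of any library. Note that glob is case-sensitive on most filesystems, so files ending in .JPG would not be counted.

```python
import pathlib
import tempfile
from collections import Counter


def count_images_per_class(root, pattern="*.jpg"):
    """Count files matching `pattern` in each immediate subdirectory of `root`.

    Assumes the layout root/<class_name>/<image>.jpg. Subdirectories are
    visited in sorted order, i.e. the alphabetical order used for class_names.
    """
    root = pathlib.Path(root)
    counts = Counter()
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[class_dir.name] = len(list(class_dir.glob(pattern)))
    return counts


# Demo on a throwaway directory tree so the snippet runs anywhere;
# with the real data you would call count_images_per_class(data_dir).
with tempfile.TemporaryDirectory() as tmp:
    for name, n in [("Brad Pitt", 3), ("Angelina Jolie", 2)]:
        class_dir = pathlib.Path(tmp) / name
        class_dir.mkdir()
        for i in range(n):
            (class_dir / f"{i}.jpg").touch()
    demo_counts = count_images_per_class(tmp)

print(dict(demo_counts))
print("total:", sum(demo_counts.values()))
```

The demo prints the counts keyed in alphabetical order ("Angelina Jolie" before "Brad Pitt"), even though the directories were created in the opposite order.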

Open one image:

nicole_images = list(data_dir.glob('Nicole Kidman/*.jpg'))
PIL.Image.open(str(nicole_images[0]))


II. Data Preprocessing

1. Load the data

Use image_dataset_from_directory to load the images from disk into a tf.data.Dataset.

batch_size = 32
img_height = 224
img_width = 224

# Both calls must use the same validation_split and seed so that the
# training and validation subsets do not overlap.
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.1,
    subset="training",
    label_mode="categorical",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.1,
    subset="validation",
    label_mode="categorical",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
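The number of images in each subset follows from validation_split. As far as we know, Keras assigns int(validation_split * total) files to the validation subset and the rest to training. The plain-Python sketch below applies that rule. The helper split_sizes is hypothetical. With the 1,800 images counted earlier, validation_split=0.1 gives 51 training batches per epoch, which matches the 51/51 steps in the second training log. The 45/45 steps in the first log would instead correspond to validation_split=0.2, so that run may have used a different split.

```python
import math


def split_sizes(total, validation_split, batch_size):
    """Estimate subset and batch counts for image_dataset_from_directory.

    Assumption: the validation subset gets int(validation_split * total)
    files and the training subset gets the remainder; the last batch of
    each subset may be smaller than batch_size.
    """
    n_val = int(validation_split * total)
    n_train = total - n_val
    return {
        "train": n_train,
        "val": n_val,
        "train_batches": math.ceil(n_train / batch_size),
        "val_batches": math.ceil(n_val / batch_size),
    }


print(split_sizes(1800, 0.1, 32))
print(split_sizes(1800, 0.2, 32))
```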

class_names lists the dataset's labels. They correspond to the subdirectory names in alphabetical order.

class_names = train_ds.class_names
print(class_names)

Output:

['Angelina Jolie', 'Brad Pitt', 'Denzel Washington', 'Hugh Jackman', 'Jennifer Lawrence', 'Johnny Depp', 'Kate Winslet', 'Leonardo DiCaprio', 'Megan Fox', 'Natalie Portman', 'Nicole Kidman', 'Robert Downey Jr', 'Sandra Bullock', 'Scarlett Johansson', 'Tom Cruise', 'Tom Hanks', 'Will Smith']

2. Visualize the data

plt.figure(figsize=(20, 10))

for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[np.argmax(labels[i])])
        
        plt.axis("off")


3. Check the data again

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

Output:

(32, 224, 224, 3)
(32, 17)
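With label_mode="categorical", each label is a one-hot vector of length 17. That is why labels_batch has shape (32, 17), and why the plotting code calls np.argmax to recover the class index. Here is a NumPy sketch using a hypothetical batch of two labels:

```python
import numpy as np

# Class names as printed above (alphabetical directory order).
class_names = [
    'Angelina Jolie', 'Brad Pitt', 'Denzel Washington', 'Hugh Jackman',
    'Jennifer Lawrence', 'Johnny Depp', 'Kate Winslet', 'Leonardo DiCaprio',
    'Megan Fox', 'Natalie Portman', 'Nicole Kidman', 'Robert Downey Jr',
    'Sandra Bullock', 'Scarlett Johansson', 'Tom Cruise', 'Tom Hanks',
    'Will Smith',
]

# Hypothetical one-hot label batch of shape (2, 17), laid out like labels_batch.
labels_batch = np.zeros((2, len(class_names)), dtype=np.float32)
labels_batch[0, 10] = 1.0   # index 10 -> 'Nicole Kidman'
labels_batch[1, 0] = 1.0    # index 0  -> 'Angelina Jolie'

indices = np.argmax(labels_batch, axis=1)   # one class index per row
names = [class_names[i] for i in indices]
print(labels_batch.shape, indices.tolist(), names)
```

With label_mode="int", the labels would instead have shape (32,) and would be paired with SparseCategoricalCrossentropy rather than CategoricalCrossentropy.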

4. Configure the dataset

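AUTOTUNE = tf.data.AUTOTUNE
# cache() keeps the decoded images in memory after the first epoch,
# shuffle(1000) shuffles with a 1000-element buffer, and prefetch()
# prepares the next batches while the current one is being trained on.
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)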

III. Build the CNN

Load the VGG-16 network provided by Keras (ImageNet weights, without its top classifier):

from keras.applications import VGG16

conv_base = VGG16(weights='imagenet',
                  include_top=False,
                  input_shape=(224, 224, 3))
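The pipeline above feeds raw RGB pixels in [0, 255] into VGG16. The ImageNet weights that ship with Keras expect "caffe"-style preprocessing: channels reordered from RGB to BGR, and the ImageNet per-channel mean subtracted, with no rescaling. In a real pipeline you would use tf.keras.applications.vgg16.preprocess_input, for instance via train_ds.map(lambda x, y: (preprocess_input(x), y)). The NumPy version below is only our illustration of what that function does, with the mean values taken from Keras' imagenet_utils. Whether it improves accuracy on this dataset has not been tested here.

```python
import numpy as np

# Per-channel ImageNet means in BGR order, as used by Keras' "caffe"
# preprocessing mode for VGG16.
VGG_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def vgg16_preprocess(images_rgb):
    """Illustrative NumPy version of VGG16 input preprocessing.

    images_rgb: array of shape (..., 3) holding RGB values in [0, 255],
    e.g. a batch from train_ds converted with .numpy().
    Returns float32 BGR values with the ImageNet mean subtracted.
    """
    x = np.asarray(images_rgb, dtype=np.float32)
    x = x[..., ::-1]            # reverse the channel axis: RGB -> BGR
    return x - VGG_MEAN_BGR     # zero-centre each channel, no rescaling


# A 1x1 "image" whose RGB value equals the ImageNet mean maps to zeros.
mean_pixel_rgb = np.array([[[123.68, 116.779, 103.939]]])
print(vgg16_preprocess(mean_pixel_rgb))
```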

Add the fully connected layers:

model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dropout(0.4))
model.add(layers.Dense(len(class_names)))
model.summary()

Model summary:

_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 vgg16 (Functional)          (None, 7, 7, 512)         14714688  
                                                                 
 flatten (Flatten)           (None, 25088)             0         
                                                                 
 dense (Dense)               (None, 256)               6422784   
                                                                 
 dropout (Dropout)           (None, 256)               0         
                                                                 
 dense_1 (Dense)             (None, 17)                4369      
                                                                 
=================================================================
Total params: 21141841 (80.65 MB)
Trainable params: 21141841 (80.65 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
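The parameter counts in the summary follow from the layer shapes: a Dense layer has n_in × n_out weights plus n_out biases. Here is a quick check in plain Python. The VGG16 base count is copied from the summary rather than recomputed.

```python
def dense_params(n_in, n_out):
    """Parameters of a Dense layer: an n_in x n_out kernel plus n_out biases."""
    return n_in * n_out + n_out


vgg16_base = 14_714_688                   # copied from the summary above
flatten_units = 7 * 7 * 512               # Flatten of the (7, 7, 512) feature map
dense = dense_params(flatten_units, 256)
dense_1 = dense_params(256, 17)           # 17 classes
total = vgg16_base + dense + dense_1
print(flatten_units, dense, dense_1, total)   # 25088 6422784 4369 21141841

# Hypothetical alternative: GlobalAveragePooling2D would feed 512 features
# (one per channel) into the same Dense(256) layer instead of 25088.
gap_dense = dense_params(512, 256)
print(gap_dense)                              # 131328
```

Almost all of the head's parameters sit in the first Dense layer, because Flatten produces 25,088 features. A common alternative, not used in this tutorial, is GlobalAveragePooling2D. It would shrink that layer to 131,328 parameters, though whether that helps here is untested.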

IV. Train the Model

1. Set a decaying learning rate

# Initial learning rate
initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=20,
    decay_rate=0.96,
    staircase=True
)

optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

model.compile(optimizer=optimizer,
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
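ExponentialDecay is documented as lr = initial_learning_rate * decay_rate ** (step / decay_steps). Here step counts optimizer updates (batches), not epochs, and staircase=True rounds the exponent down to an integer. The sketch below re-implements that formula in plain Python, without calling TensorFlow. It assumes 45 batches per epoch, as in the first training log.

```python
import math


def exponential_decay(step, initial_lr=1e-4, decay_steps=20,
                      decay_rate=0.96, staircase=True):
    """Learning rate after `step` optimizer updates, following the formula
    documented for tf.keras.optimizers.schedules.ExponentialDecay."""
    exponent = step / decay_steps
    if staircase:
        exponent = math.floor(exponent)   # decay in discrete jumps
    return initial_lr * decay_rate ** exponent


# Assumption: 45 batches per epoch, so epoch e ends at step 45 * e.
for epoch in (1, 10, 30, 65):
    print(epoch, f"{exponential_decay(45 * epoch):.2e}")
```

Under these assumptions, the rate is about 9.2e-05 after epoch 1 but only about 2.6e-07 by epoch 65, roughly 400 times smaller. This may partly explain why the metrics barely move in the later epochs. With decay_steps=60, as in the second run, the exponent grows three times more slowly.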

2. Early stopping and saving the best model weights

from keras.callbacks import ModelCheckpoint, EarlyStopping

epochs = 100

# Save the best model weights
checkpointer = ModelCheckpoint('best_model.h5',
                                monitor='val_accuracy',
                                verbose=1,
                                save_best_only=True,
                                save_weights_only=True)

# Early stopping
earlystopper = EarlyStopping(monitor='val_accuracy',
                             min_delta=0.001,
                             patience=20,
                             verbose=1)
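EarlyStopping watches val_accuracy, which is a quantity to maximize. It counts the epochs since the last improvement larger than min_delta, and stops training once that count reaches patience. The function below is a simplified pure-Python re-implementation of that rule for illustration. It is not Keras' own code, and it ignores options such as baseline, start_from_epoch and restore_best_weights.

```python
import math


def early_stopping_epoch(values, patience, min_delta=0.0):
    """Return the 1-based epoch at which training would stop, or None.

    A value counts as an improvement only if it beats the best value so far
    by more than min_delta ("max" mode, as for val_accuracy).
    """
    best = -math.inf
    wait = 0
    for epoch, value in enumerate(values):
        wait += 1
        if value - min_delta > best:
            best = value
            wait = 0          # improvement: reset the counter
            continue
        if wait >= patience and epoch > 0:
            return epoch + 1  # convert 0-based index to 1-based epoch
    return None


# Made-up validation accuracies: the last real improvement is at epoch 3.
print(early_stopping_epoch([0.10, 0.20, 0.25, 0.25, 0.24, 0.2505],
                           patience=2, min_delta=0.001))
```

Applied to the val_accuracy column of the log below, the same rule stops at epoch 65. The last improvement larger than 0.001 is at epoch 45 (0.35833), and it is followed by 20 epochs without one. Because restore_best_weights is left at its default (False), the model in memory afterwards holds the epoch-65 weights. The best weights are in best_model.h5 and can be restored with model.load_weights('best_model.h5').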

3. Train the model

history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=epochs,
                    callbacks=[checkpointer, earlystopper])

Output:

Epoch 1/100
45/45 [==============================] - ETA: 0s - loss: 3.8728 - accuracy: 0.0639
Epoch 1: val_accuracy improved from -inf to 0.05833, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 392s 2s/step - loss: 3.8728 - accuracy: 0.0639 - val_loss: 2.8329 - val_accuracy: 0.0583
Epoch 2/100
45/45 [==============================] - ETA: 0s - loss: 2.8281 - accuracy: 0.0750
Epoch 2: val_accuracy improved from 0.05833 to 0.08611, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 20s 435ms/step - loss: 2.8281 - accuracy: 0.0750 - val_loss: 2.8312 - val_accuracy: 0.0861
Epoch 3/100
45/45 [==============================] - ETA: 0s - loss: 2.8119 - accuracy: 0.0986
Epoch 3: val_accuracy improved from 0.08611 to 0.12500, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 20s 441ms/step - loss: 2.8119 - accuracy: 0.0986 - val_loss: 2.7734 - val_accuracy: 0.1250
Epoch 4/100
45/45 [==============================] - ETA: 0s - loss: 2.7812 - accuracy: 0.1153
Epoch 4: val_accuracy improved from 0.12500 to 0.13333, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 20s 444ms/step - loss: 2.7812 - accuracy: 0.1153 - val_loss: 2.7716 - val_accuracy: 0.1333
Epoch 5/100
45/45 [==============================] - ETA: 0s - loss: 2.7444 - accuracy: 0.1319
Epoch 5: val_accuracy improved from 0.13333 to 0.13889, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 20s 455ms/step - loss: 2.7444 - accuracy: 0.1319 - val_loss: 2.7138 - val_accuracy: 0.1389
Epoch 6/100
45/45 [==============================] - ETA: 0s - loss: 2.6869 - accuracy: 0.1458
Epoch 6: val_accuracy improved from 0.13889 to 0.15556, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 22s 493ms/step - loss: 2.6869 - accuracy: 0.1458 - val_loss: 2.6313 - val_accuracy: 0.1556
Epoch 7/100
45/45 [==============================] - ETA: 0s - loss: 2.6230 - accuracy: 0.1507
Epoch 7: val_accuracy did not improve from 0.15556
45/45 [==============================] - 19s 426ms/step - loss: 2.6230 - accuracy: 0.1507 - val_loss: 2.6150 - val_accuracy: 0.1500
Epoch 8/100
45/45 [==============================] - ETA: 0s - loss: 2.5483 - accuracy: 0.1826
Epoch 8: val_accuracy did not improve from 0.15556
45/45 [==============================] - 19s 431ms/step - loss: 2.5483 - accuracy: 0.1826 - val_loss: 2.6042 - val_accuracy: 0.1361
Epoch 9/100
45/45 [==============================] - ETA: 0s - loss: 2.5030 - accuracy: 0.1840
Epoch 9: val_accuracy improved from 0.15556 to 0.16389, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 21s 474ms/step - loss: 2.5030 - accuracy: 0.1840 - val_loss: 2.5547 - val_accuracy: 0.1639
Epoch 10/100
45/45 [==============================] - ETA: 0s - loss: 2.4198 - accuracy: 0.2208
Epoch 10: val_accuracy improved from 0.16389 to 0.18056, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 21s 471ms/step - loss: 2.4198 - accuracy: 0.2208 - val_loss: 2.5333 - val_accuracy: 0.1806
Epoch 11/100
45/45 [==============================] - ETA: 0s - loss: 2.3769 - accuracy: 0.2222
Epoch 11: val_accuracy improved from 0.18056 to 0.21944, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 21s 466ms/step - loss: 2.3769 - accuracy: 0.2222 - val_loss: 2.4761 - val_accuracy: 0.2194
Epoch 12/100
45/45 [==============================] - ETA: 0s - loss: 2.2730 - accuracy: 0.2583
Epoch 12: val_accuracy did not improve from 0.21944
45/45 [==============================] - 20s 442ms/step - loss: 2.2730 - accuracy: 0.2583 - val_loss: 2.4136 - val_accuracy: 0.2194
Epoch 13/100
45/45 [==============================] - ETA: 0s - loss: 2.1822 - accuracy: 0.2944
Epoch 13: val_accuracy did not improve from 0.21944
45/45 [==============================] - 20s 444ms/step - loss: 2.1822 - accuracy: 0.2944 - val_loss: 2.3941 - val_accuracy: 0.2056
Epoch 14/100
45/45 [==============================] - ETA: 0s - loss: 2.0759 - accuracy: 0.3229
Epoch 14: val_accuracy improved from 0.21944 to 0.26389, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 22s 481ms/step - loss: 2.0759 - accuracy: 0.3229 - val_loss: 2.3304 - val_accuracy: 0.2639
Epoch 15/100
45/45 [==============================] - ETA: 0s - loss: 1.9345 - accuracy: 0.3542
Epoch 15: val_accuracy did not improve from 0.26389
45/45 [==============================] - 20s 444ms/step - loss: 1.9345 - accuracy: 0.3542 - val_loss: 2.2658 - val_accuracy: 0.2556
Epoch 16/100
45/45 [==============================] - ETA: 0s - loss: 1.8489 - accuracy: 0.3743
Epoch 16: val_accuracy improved from 0.26389 to 0.29444, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 22s 482ms/step - loss: 1.8489 - accuracy: 0.3743 - val_loss: 2.2204 - val_accuracy: 0.2944
Epoch 17/100
45/45 [==============================] - ETA: 0s - loss: 1.7346 - accuracy: 0.4229
Epoch 17: val_accuracy did not improve from 0.29444
45/45 [==============================] - 20s 444ms/step - loss: 1.7346 - accuracy: 0.4229 - val_loss: 2.2433 - val_accuracy: 0.2778
Epoch 18/100
45/45 [==============================] - ETA: 0s - loss: 1.5925 - accuracy: 0.4847
Epoch 18: val_accuracy improved from 0.29444 to 0.30278, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 22s 488ms/step - loss: 1.5925 - accuracy: 0.4847 - val_loss: 2.2539 - val_accuracy: 0.3028
Epoch 19/100
45/45 [==============================] - ETA: 0s - loss: 1.5517 - accuracy: 0.4938
Epoch 19: val_accuracy did not improve from 0.30278
45/45 [==============================] - 20s 443ms/step - loss: 1.5517 - accuracy: 0.4938 - val_loss: 2.2492 - val_accuracy: 0.2972
Epoch 20/100
45/45 [==============================] - ETA: 0s - loss: 1.3983 - accuracy: 0.5285
Epoch 20: val_accuracy did not improve from 0.30278
45/45 [==============================] - 20s 451ms/step - loss: 1.3983 - accuracy: 0.5285 - val_loss: 2.2970 - val_accuracy: 0.3028
Epoch 21/100
45/45 [==============================] - ETA: 0s - loss: 1.2226 - accuracy: 0.5840
Epoch 21: val_accuracy improved from 0.30278 to 0.31667, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 23s 507ms/step - loss: 1.2226 - accuracy: 0.5840 - val_loss: 2.2511 - val_accuracy: 0.3167
Epoch 22/100
45/45 [==============================] - ETA: 0s - loss: 1.1661 - accuracy: 0.6201
Epoch 22: val_accuracy did not improve from 0.31667
45/45 [==============================] - 21s 464ms/step - loss: 1.1661 - accuracy: 0.6201 - val_loss: 2.3456 - val_accuracy: 0.3000
Epoch 23/100
45/45 [==============================] - ETA: 0s - loss: 1.0238 - accuracy: 0.6590
Epoch 23: val_accuracy did not improve from 0.31667
45/45 [==============================] - 20s 436ms/step - loss: 1.0238 - accuracy: 0.6590 - val_loss: 2.4425 - val_accuracy: 0.3083
Epoch 24/100
45/45 [==============================] - ETA: 0s - loss: 0.9449 - accuracy: 0.6868
Epoch 24: val_accuracy improved from 0.31667 to 0.32778, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 22s 486ms/step - loss: 0.9449 - accuracy: 0.6868 - val_loss: 2.3126 - val_accuracy: 0.3278
Epoch 25/100
45/45 [==============================] - ETA: 0s - loss: 0.8259 - accuracy: 0.7229
Epoch 25: val_accuracy did not improve from 0.32778
45/45 [==============================] - 20s 434ms/step - loss: 0.8259 - accuracy: 0.7229 - val_loss: 2.3506 - val_accuracy: 0.3000
Epoch 26/100
45/45 [==============================] - ETA: 0s - loss: 0.7882 - accuracy: 0.7333
Epoch 26: val_accuracy did not improve from 0.32778
45/45 [==============================] - 20s 438ms/step - loss: 0.7882 - accuracy: 0.7333 - val_loss: 2.3976 - val_accuracy: 0.3083
Epoch 27/100
45/45 [==============================] - ETA: 0s - loss: 0.6816 - accuracy: 0.7806
Epoch 27: val_accuracy did not improve from 0.32778
45/45 [==============================] - 20s 441ms/step - loss: 0.6816 - accuracy: 0.7806 - val_loss: 2.5215 - val_accuracy: 0.3167
Epoch 28/100
45/45 [==============================] - ETA: 0s - loss: 0.6466 - accuracy: 0.7931
Epoch 28: val_accuracy did not improve from 0.32778
45/45 [==============================] - 20s 444ms/step - loss: 0.6466 - accuracy: 0.7931 - val_loss: 2.4860 - val_accuracy: 0.3194
Epoch 29/100
45/45 [==============================] - ETA: 0s - loss: 0.5820 - accuracy: 0.8062
Epoch 29: val_accuracy did not improve from 0.32778
45/45 [==============================] - 20s 446ms/step - loss: 0.5820 - accuracy: 0.8062 - val_loss: 2.4623 - val_accuracy: 0.3194
Epoch 30/100
45/45 [==============================] - ETA: 0s - loss: 0.5293 - accuracy: 0.8292
Epoch 30: val_accuracy did not improve from 0.32778
45/45 [==============================] - 20s 445ms/step - loss: 0.5293 - accuracy: 0.8292 - val_loss: 2.5641 - val_accuracy: 0.3222
Epoch 31/100
45/45 [==============================] - ETA: 0s - loss: 0.4870 - accuracy: 0.8403
Epoch 31: val_accuracy did not improve from 0.32778
45/45 [==============================] - 20s 443ms/step - loss: 0.4870 - accuracy: 0.8403 - val_loss: 2.6284 - val_accuracy: 0.3250
Epoch 32/100
45/45 [==============================] - ETA: 0s - loss: 0.4601 - accuracy: 0.8507
Epoch 32: val_accuracy improved from 0.32778 to 0.33889, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 22s 482ms/step - loss: 0.4601 - accuracy: 0.8507 - val_loss: 2.6226 - val_accuracy: 0.3389
Epoch 33/100
45/45 [==============================] - ETA: 0s - loss: 0.4220 - accuracy: 0.8708
Epoch 33: val_accuracy did not improve from 0.33889
45/45 [==============================] - 20s 446ms/step - loss: 0.4220 - accuracy: 0.8708 - val_loss: 2.6235 - val_accuracy: 0.3278
Epoch 34/100
45/45 [==============================] - ETA: 0s - loss: 0.4048 - accuracy: 0.8708
Epoch 34: val_accuracy improved from 0.33889 to 0.34444, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 22s 492ms/step - loss: 0.4048 - accuracy: 0.8708 - val_loss: 2.5722 - val_accuracy: 0.3444
Epoch 35/100
45/45 [==============================] - ETA: 0s - loss: 0.3656 - accuracy: 0.8889
Epoch 35: val_accuracy did not improve from 0.34444
45/45 [==============================] - 21s 469ms/step - loss: 0.3656 - accuracy: 0.8889 - val_loss: 2.7488 - val_accuracy: 0.3389
Epoch 36/100
45/45 [==============================] - ETA: 0s - loss: 0.3666 - accuracy: 0.8889
Epoch 36: val_accuracy did not improve from 0.34444
45/45 [==============================] - 21s 463ms/step - loss: 0.3666 - accuracy: 0.8889 - val_loss: 2.7453 - val_accuracy: 0.3278
Epoch 37/100
45/45 [==============================] - ETA: 0s - loss: 0.3088 - accuracy: 0.9076
Epoch 37: val_accuracy improved from 0.34444 to 0.35000, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 21s 469ms/step - loss: 0.3088 - accuracy: 0.9076 - val_loss: 2.7392 - val_accuracy: 0.3500
Epoch 38/100
45/45 [==============================] - ETA: 0s - loss: 0.3152 - accuracy: 0.8986
Epoch 38: val_accuracy improved from 0.35000 to 0.35278, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 21s 467ms/step - loss: 0.3152 - accuracy: 0.8986 - val_loss: 2.7681 - val_accuracy: 0.3528
Epoch 39/100
45/45 [==============================] - ETA: 0s - loss: 0.2952 - accuracy: 0.9111
Epoch 39: val_accuracy did not improve from 0.35278
45/45 [==============================] - 20s 442ms/step - loss: 0.2952 - accuracy: 0.9111 - val_loss: 2.7665 - val_accuracy: 0.3417
Epoch 40/100
45/45 [==============================] - ETA: 0s - loss: 0.2937 - accuracy: 0.9056
Epoch 40: val_accuracy did not improve from 0.35278
45/45 [==============================] - 20s 444ms/step - loss: 0.2937 - accuracy: 0.9056 - val_loss: 2.8206 - val_accuracy: 0.3417
Epoch 41/100
45/45 [==============================] - ETA: 0s - loss: 0.2608 - accuracy: 0.9160
Epoch 41: val_accuracy improved from 0.35278 to 0.35556, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 22s 489ms/step - loss: 0.2608 - accuracy: 0.9160 - val_loss: 2.8521 - val_accuracy: 0.3556
Epoch 42/100
45/45 [==============================] - ETA: 0s - loss: 0.2588 - accuracy: 0.9167
Epoch 42: val_accuracy did not improve from 0.35556
45/45 [==============================] - 20s 445ms/step - loss: 0.2588 - accuracy: 0.9167 - val_loss: 2.8687 - val_accuracy: 0.3361
Epoch 43/100
45/45 [==============================] - ETA: 0s - loss: 0.2643 - accuracy: 0.9153
Epoch 43: val_accuracy did not improve from 0.35556
45/45 [==============================] - 20s 446ms/step - loss: 0.2643 - accuracy: 0.9153 - val_loss: 2.8563 - val_accuracy: 0.3306
Epoch 44/100
45/45 [==============================] - ETA: 0s - loss: 0.2337 - accuracy: 0.9326
Epoch 44: val_accuracy did not improve from 0.35556
45/45 [==============================] - 21s 471ms/step - loss: 0.2337 - accuracy: 0.9326 - val_loss: 2.8820 - val_accuracy: 0.3417
Epoch 45/100
45/45 [==============================] - ETA: 0s - loss: 0.2270 - accuracy: 0.9264
Epoch 45: val_accuracy improved from 0.35556 to 0.35833, saving model to /content/drive/MyDrive/app/T6/best_model.h5
45/45 [==============================] - 22s 485ms/step - loss: 0.2270 - accuracy: 0.9264 - val_loss: 2.9108 - val_accuracy: 0.3583
Epoch 46/100
45/45 [==============================] - ETA: 0s - loss: 0.2310 - accuracy: 0.9382
Epoch 46: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 447ms/step - loss: 0.2310 - accuracy: 0.9382 - val_loss: 2.8827 - val_accuracy: 0.3556
Epoch 47/100
45/45 [==============================] - ETA: 0s - loss: 0.2290 - accuracy: 0.9347
Epoch 47: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 447ms/step - loss: 0.2290 - accuracy: 0.9347 - val_loss: 2.8759 - val_accuracy: 0.3528
Epoch 48/100
45/45 [==============================] - ETA: 0s - loss: 0.2132 - accuracy: 0.9361
Epoch 48: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 447ms/step - loss: 0.2132 - accuracy: 0.9361 - val_loss: 2.8647 - val_accuracy: 0.3472
Epoch 49/100
45/45 [==============================] - ETA: 0s - loss: 0.2449 - accuracy: 0.9243
Epoch 49: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 445ms/step - loss: 0.2449 - accuracy: 0.9243 - val_loss: 2.8989 - val_accuracy: 0.3333
Epoch 50/100
45/45 [==============================] - ETA: 0s - loss: 0.2371 - accuracy: 0.9229
Epoch 50: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 445ms/step - loss: 0.2371 - accuracy: 0.9229 - val_loss: 2.8993 - val_accuracy: 0.3361
Epoch 51/100
45/45 [==============================] - ETA: 0s - loss: 0.2011 - accuracy: 0.9458
Epoch 51: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 445ms/step - loss: 0.2011 - accuracy: 0.9458 - val_loss: 2.8976 - val_accuracy: 0.3389
Epoch 52/100
45/45 [==============================] - ETA: 0s - loss: 0.2190 - accuracy: 0.9347
Epoch 52: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 445ms/step - loss: 0.2190 - accuracy: 0.9347 - val_loss: 2.9062 - val_accuracy: 0.3417
Epoch 53/100
45/45 [==============================] - ETA: 0s - loss: 0.2196 - accuracy: 0.9312
Epoch 53: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 445ms/step - loss: 0.2196 - accuracy: 0.9312 - val_loss: 2.9152 - val_accuracy: 0.3389
Epoch 54/100
45/45 [==============================] - ETA: 0s - loss: 0.2086 - accuracy: 0.9417
Epoch 54: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 445ms/step - loss: 0.2086 - accuracy: 0.9417 - val_loss: 2.8989 - val_accuracy: 0.3472
Epoch 55/100
45/45 [==============================] - ETA: 0s - loss: 0.2074 - accuracy: 0.9417
Epoch 55: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 446ms/step - loss: 0.2074 - accuracy: 0.9417 - val_loss: 2.9394 - val_accuracy: 0.3444
Epoch 56/100
45/45 [==============================] - ETA: 0s - loss: 0.2061 - accuracy: 0.9410
Epoch 56: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 446ms/step - loss: 0.2061 - accuracy: 0.9410 - val_loss: 2.9220 - val_accuracy: 0.3389
Epoch 57/100
45/45 [==============================] - ETA: 0s - loss: 0.1886 - accuracy: 0.9514
Epoch 57: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 446ms/step - loss: 0.1886 - accuracy: 0.9514 - val_loss: 2.9286 - val_accuracy: 0.3333
Epoch 58/100
45/45 [==============================] - ETA: 0s - loss: 0.1918 - accuracy: 0.9472
Epoch 58: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 446ms/step - loss: 0.1918 - accuracy: 0.9472 - val_loss: 2.9408 - val_accuracy: 0.3333
Epoch 59/100
45/45 [==============================] - ETA: 0s - loss: 0.1918 - accuracy: 0.9424
Epoch 59: val_accuracy did not improve from 0.35833
45/45 [==============================] - 21s 472ms/step - loss: 0.1918 - accuracy: 0.9424 - val_loss: 2.9616 - val_accuracy: 0.3361
Epoch 60/100
45/45 [==============================] - ETA: 0s - loss: 0.1911 - accuracy: 0.9424
Epoch 60: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 444ms/step - loss: 0.1911 - accuracy: 0.9424 - val_loss: 2.9689 - val_accuracy: 0.3333
Epoch 61/100
45/45 [==============================] - ETA: 0s - loss: 0.1828 - accuracy: 0.9458
Epoch 61: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 447ms/step - loss: 0.1828 - accuracy: 0.9458 - val_loss: 2.9638 - val_accuracy: 0.3333
Epoch 62/100
45/45 [==============================] - ETA: 0s - loss: 0.1862 - accuracy: 0.9396
Epoch 62: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 446ms/step - loss: 0.1862 - accuracy: 0.9396 - val_loss: 2.9723 - val_accuracy: 0.3333
Epoch 63/100
45/45 [==============================] - ETA: 0s - loss: 0.1807 - accuracy: 0.9424
Epoch 63: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 446ms/step - loss: 0.1807 - accuracy: 0.9424 - val_loss: 2.9732 - val_accuracy: 0.3361
Epoch 64/100
45/45 [==============================] - ETA: 0s - loss: 0.1731 - accuracy: 0.9542
Epoch 64: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 445ms/step - loss: 0.1731 - accuracy: 0.9542 - val_loss: 2.9762 - val_accuracy: 0.3389
Epoch 65/100
45/45 [==============================] - ETA: 0s - loss: 0.1733 - accuracy: 0.9576
Epoch 65: val_accuracy did not improve from 0.35833
45/45 [==============================] - 20s 445ms/step - loss: 0.1733 - accuracy: 0.9576 - val_loss: 2.9801 - val_accuracy: 0.3333
Epoch 65: early stopping

V. Evaluate the Model

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(len(loss))

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

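The model performs poorly. Training accuracy climbs to about 95%, while validation accuracy plateaus at around 33-36%. Meanwhile, validation loss rises from its minimum of about 2.22 (epoch 16) to about 2.98, a clear sign of overfitting. The next section tries to improve on this.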
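The overfitting can be quantified directly from history.history instead of being read off the plots. Below is a plain-Python sketch. The helper summarize_history is hypothetical, and it assumes the dictionary has the accuracy, val_accuracy, loss and val_loss keys that the plotting code above already reads.

```python
def summarize_history(history):
    """Report the best-val_accuracy epoch and the train/validation gap there.

    `history` is a dict shaped like Keras' History.history, e.g.
    {'accuracy': [...], 'val_accuracy': [...], 'loss': [...], 'val_loss': [...]}.
    Epochs are reported 1-based, as in the training log.
    """
    val_acc = history["val_accuracy"]
    val_loss = history["val_loss"]
    best = max(range(len(val_acc)), key=lambda i: val_acc[i])
    lowest_loss = min(range(len(val_loss)), key=lambda i: val_loss[i])
    return {
        "best_epoch": best + 1,
        "train_acc": history["accuracy"][best],
        "val_acc": val_acc[best],
        "gap": history["accuracy"][best] - val_acc[best],
        "min_val_loss_epoch": lowest_loss + 1,
    }


# Usage with the real run: summarize_history(history.history)
toy = {"accuracy": [0.5, 0.8, 0.95], "val_accuracy": [0.30, 0.36, 0.35],
       "loss": [2.0, 1.0, 0.2], "val_loss": [2.2, 2.4, 2.9]}
print(summarize_history(toy))
```

For the run above, this would report epoch 45 (training 0.9264 vs. validation 0.3583, a gap of about 0.57), while val_loss reached its minimum much earlier, at epoch 16.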

VI. Optimize the Model

1. Freeze the convolutional base and train only the fully connected layers

from keras.applications import VGG16

vgg_model = VGG16(weights='imagenet', include_top=False, input_shape=(img_height, img_width, 3))

model = models.Sequential()
model.add(vgg_model)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dropout(0.5))
# No softmax here: the loss is compiled with from_logits=True,
# so the last layer must output raw logits.
model.add(layers.Dense(len(class_names)))
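The loss used in this tutorial is CategoricalCrossentropy(from_logits=True), which applies softmax to the model output itself. That matches a final Dense layer without an activation, as in the first model. If the final layer uses activation='softmax', softmax is applied twice. This squashes the predicted distribution toward uniform and distorts the loss, so either drop the activation or compile with from_logits=False. The NumPy sketch below uses a simplified cross-entropy, not Keras' implementation, to show the effect on one made-up example.

```python
import numpy as np


def softmax(z):
    z = z - np.max(z, axis=-1, keepdims=True)   # subtract max for numerical stability
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)


def categorical_crossentropy(y_true, y_pred, from_logits=False):
    """Mean categorical cross-entropy for one-hot y_true (simplified)."""
    probs = softmax(y_pred) if from_logits else y_pred
    probs = np.clip(probs, 1e-7, 1.0)           # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(probs), axis=-1)))


y_true = np.array([[0.0, 1.0, 0.0]])
logits = np.array([[1.0, 4.0, 0.5]])            # raw Dense output, no activation

correct = categorical_crossentropy(y_true, logits, from_logits=True)
same = categorical_crossentropy(y_true, softmax(logits), from_logits=False)
double = categorical_crossentropy(y_true, softmax(logits), from_logits=True)
print(correct, same, double)
```

The first two losses agree. The third, where softmax is applied twice, is several times larger even though the prediction is confidently correct.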

In Keras, you freeze a network (or a layer) by setting its trainable attribute to False.
First, check how many weight tensors are currently trainable:

print('This is the number of trainable weights before freezing the conv base:', len(model.trainable_weights))

Output:

This is the number of trainable weights before freezing the conv base: 30
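VGG16 has 13 conv layers, and we added 2 Dense layers. Each has two weight tensors (the kernel and the bias vector), giving 13 × 2 + 2 × 2 = 30.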

Now set the trainable attribute of vgg_model to False:

# Freeze the convolutional base
vgg_model.trainable = False

Check again:

print('This is the number of trainable weights after freezing the conv base:', len(model.trainable_weights))

Output:

This is the number of trainable weights after freezing the conv base: 4

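With this setting, only the weights of the two added Dense layers are trained. The trainable flag must be changed before model.compile() for it to take effect, which is the order used here.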
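The counts of 30 and 4 weight tensors can be checked by hand from VGG16's layout of thirteen 3×3 conv layers. The filter list below is the standard VGG16 configuration. This plain-Python sketch also shows that freezing leaves 6,427,153 of the 21,141,841 parameters trainable, which is what model.summary() should now report under "Trainable params".

```python
# Output channels of VGG16's 13 conv layers (all 3x3 kernels).
FILTERS = [64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512]


def conv_params(in_ch, out_ch, k=3):
    """k x k x in_ch x out_ch kernel weights plus out_ch biases."""
    return k * k * in_ch * out_ch + out_ch


base_params = 0
in_ch = 3                                   # RGB input
for out_ch in FILTERS:
    base_params += conv_params(in_ch, out_ch)
    in_ch = out_ch

# Head: Dense(256) on the flattened 7x7x512 map, then Dense(17).
head_params = (7 * 7 * 512 * 256 + 256) + (256 * 17 + 17)

# Each conv / Dense layer holds two weight tensors: kernel and bias.
weight_tensors_before = 2 * len(FILTERS) + 2 * 2
weight_tensors_after = 2 * 2                # only the two Dense layers remain trainable

print(base_params, head_params, base_params + head_params)
print(weight_tensors_before, weight_tensors_after)
```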

2. Set the learning rate schedule, ModelCheckpoint and EarlyStopping

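These are the same as before, except that decay_steps is raised from 20 to 60 so that the learning rate decays more slowly.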

# Initial learning rate
initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=60,
    decay_rate=0.96,
    staircase=True
)

optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

model.compile(optimizer=optimizer,
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

from keras.callbacks import ModelCheckpoint, EarlyStopping

epochs = 100

# Save the best model weights to a separate file so the first run's best_model.h5 is kept
checkpointer = ModelCheckpoint('best_model2.h5',
                                monitor='val_accuracy',
                                verbose=1,
                                save_best_only=True,
                                save_weights_only=True)

# Early stopping
earlystopper = EarlyStopping(monitor='val_accuracy',
                             min_delta=0.001,
                             patience=20,
                             verbose=1)

3. Train again

history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=epochs,
                    callbacks=[checkpointer, earlystopper])

Output:

Epoch 1/100
51/51 [==============================] - ETA: 0s - loss: 15.5740 - accuracy: 0.1074
Epoch 1: val_accuracy improved from -inf to 0.35000, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 293s 845ms/step - loss: 15.5740 - accuracy: 0.1074 - val_loss: 2.6472 - val_accuracy: 0.3500
Epoch 2/100
51/51 [==============================] - ETA: 0s - loss: 2.6200 - accuracy: 0.2994
Epoch 2: val_accuracy improved from 0.35000 to 0.43611, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 9s 173ms/step - loss: 2.6200 - accuracy: 0.2994 - val_loss: 1.9987 - val_accuracy: 0.4361
Epoch 3/100
51/51 [==============================] - ETA: 0s - loss: 1.8993 - accuracy: 0.4179
Epoch 3: val_accuracy improved from 0.43611 to 0.52778, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 8s 159ms/step - loss: 1.8993 - accuracy: 0.4179 - val_loss: 1.7028 - val_accuracy: 0.5278
Epoch 4/100
51/51 [==============================] - ETA: 0s - loss: 1.6321 - accuracy: 0.4772
Epoch 4: val_accuracy improved from 0.52778 to 0.56944, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 10s 197ms/step - loss: 1.6321 - accuracy: 0.4772 - val_loss: 1.6052 - val_accuracy: 0.5694
Epoch 5/100
51/51 [==============================] - ETA: 0s - loss: 1.4345 - accuracy: 0.5136
Epoch 5: val_accuracy improved from 0.56944 to 0.59722, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 9s 186ms/step - loss: 1.4345 - accuracy: 0.5136 - val_loss: 1.5235 - val_accuracy: 0.5972
Epoch 6/100
51/51 [==============================] - ETA: 0s - loss: 1.2427 - accuracy: 0.5537
Epoch 6: val_accuracy improved from 0.59722 to 0.62222, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 9s 180ms/step - loss: 1.2427 - accuracy: 0.5537 - val_loss: 1.5024 - val_accuracy: 0.6222
Epoch 7/100
51/51 [==============================] - ETA: 0s - loss: 1.0806 - accuracy: 0.6130
Epoch 7: val_accuracy improved from 0.62222 to 0.62778, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 10s 190ms/step - loss: 1.0806 - accuracy: 0.6130 - val_loss: 1.4337 - val_accuracy: 0.6278
Epoch 8/100
51/51 [==============================] - ETA: 0s - loss: 0.9529 - accuracy: 0.6611
Epoch 8: val_accuracy improved from 0.62778 to 0.65833, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 9s 186ms/step - loss: 0.9529 - accuracy: 0.6611 - val_loss: 1.3828 - val_accuracy: 0.6583
Epoch 9/100
51/51 [==============================] - ETA: 0s - loss: 0.8889 - accuracy: 0.6716
Epoch 9: val_accuracy did not improve from 0.65833
51/51 [==============================] - 9s 186ms/step - loss: 0.8889 - accuracy: 0.6716 - val_loss: 1.3703 - val_accuracy: 0.6500
Epoch 10/100
51/51 [==============================] - ETA: 0s - loss: 0.7911 - accuracy: 0.7123
Epoch 10: val_accuracy improved from 0.65833 to 0.67222, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 9s 181ms/step - loss: 0.7911 - accuracy: 0.7123 - val_loss: 1.3866 - val_accuracy: 0.6722
Epoch 11/100
51/51 [==============================] - ETA: 0s - loss: 0.7064 - accuracy: 0.7543
Epoch 11: val_accuracy improved from 0.67222 to 0.68333, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 8s 167ms/step - loss: 0.7064 - accuracy: 0.7543 - val_loss: 1.3614 - val_accuracy: 0.6833
Epoch 12/100
51/51 [==============================] - ETA: 0s - loss: 0.6766 - accuracy: 0.7593
Epoch 12: val_accuracy did not improve from 0.68333
51/51 [==============================] - 8s 161ms/step - loss: 0.6766 - accuracy: 0.7593 - val_loss: 1.3141 - val_accuracy: 0.6806
Epoch 13/100
51/51 [==============================] - ETA: 0s - loss: 0.6331 - accuracy: 0.7728
Epoch 13: val_accuracy did not improve from 0.68333
51/51 [==============================] - 8s 160ms/step - loss: 0.6331 - accuracy: 0.7728 - val_loss: 1.3756 - val_accuracy: 0.6778
Epoch 14/100
51/51 [==============================] - ETA: 0s - loss: 0.5838 - accuracy: 0.7877
Epoch 14: val_accuracy improved from 0.68333 to 0.69167, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 9s 179ms/step - loss: 0.5838 - accuracy: 0.7877 - val_loss: 1.3946 - val_accuracy: 0.6917
Epoch 15/100
51/51 [==============================] - ETA: 0s - loss: 0.5148 - accuracy: 0.8123
Epoch 15: val_accuracy did not improve from 0.69167
51/51 [==============================] - 8s 160ms/step - loss: 0.5148 - accuracy: 0.8123 - val_loss: 1.4108 - val_accuracy: 0.6778
Epoch 16/100
51/51 [==============================] - ETA: 0s - loss: 0.4900 - accuracy: 0.8062
Epoch 16: val_accuracy did not improve from 0.69167
51/51 [==============================] - 8s 161ms/step - loss: 0.4900 - accuracy: 0.8062 - val_loss: 1.3512 - val_accuracy: 0.6889
Epoch 17/100
51/51 [==============================] - ETA: 0s - loss: 0.4834 - accuracy: 0.8173
Epoch 17: val_accuracy did not improve from 0.69167
51/51 [==============================] - 8s 161ms/step - loss: 0.4834 - accuracy: 0.8173 - val_loss: 1.3588 - val_accuracy: 0.6889
Epoch 18/100
51/51 [==============================] - ETA: 0s - loss: 0.4556 - accuracy: 0.8278
Epoch 18: val_accuracy improved from 0.69167 to 0.70000, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 10s 202ms/step - loss: 0.4556 - accuracy: 0.8278 - val_loss: 1.3768 - val_accuracy: 0.7000
Epoch 19/100
51/51 [==============================] - ETA: 0s - loss: 0.4350 - accuracy: 0.8457
Epoch 19: val_accuracy improved from 0.70000 to 0.70556, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 9s 167ms/step - loss: 0.4350 - accuracy: 0.8457 - val_loss: 1.3625 - val_accuracy: 0.7056
Epoch 20/100
51/51 [==============================] - ETA: 0s - loss: 0.3907 - accuracy: 0.8667
Epoch 20: val_accuracy improved from 0.70556 to 0.70833, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 9s 181ms/step - loss: 0.3907 - accuracy: 0.8667 - val_loss: 1.3188 - val_accuracy: 0.7083
Epoch 21/100
51/51 [==============================] - ETA: 0s - loss: 0.4094 - accuracy: 0.8414
Epoch 21: val_accuracy did not improve from 0.70833
51/51 [==============================] - 8s 162ms/step - loss: 0.4094 - accuracy: 0.8414 - val_loss: 1.3284 - val_accuracy: 0.7056
Epoch 22/100
51/51 [==============================] - ETA: 0s - loss: 0.3572 - accuracy: 0.8642
Epoch 22: val_accuracy did not improve from 0.70833
51/51 [==============================] - 9s 182ms/step - loss: 0.3572 - accuracy: 0.8642 - val_loss: 1.3724 - val_accuracy: 0.7083
Epoch 23/100
51/51 [==============================] - ETA: 0s - loss: 0.3208 - accuracy: 0.8827
Epoch 23: val_accuracy improved from 0.70833 to 0.72500, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 9s 185ms/step - loss: 0.3208 - accuracy: 0.8827 - val_loss: 1.3418 - val_accuracy: 0.7250
Epoch 24/100
51/51 [==============================] - ETA: 0s - loss: 0.3316 - accuracy: 0.8673
Epoch 24: val_accuracy did not improve from 0.72500
51/51 [==============================] - 8s 162ms/step - loss: 0.3316 - accuracy: 0.8673 - val_loss: 1.3702 - val_accuracy: 0.7222
Epoch 25/100
51/51 [==============================] - ETA: 0s - loss: 0.3231 - accuracy: 0.8796
Epoch 25: val_accuracy did not improve from 0.72500
51/51 [==============================] - 8s 162ms/step - loss: 0.3231 - accuracy: 0.8796 - val_loss: 1.3799 - val_accuracy: 0.7222
Epoch 26/100
51/51 [==============================] - ETA: 0s - loss: 0.2842 - accuracy: 0.8920
Epoch 26: val_accuracy did not improve from 0.72500
51/51 [==============================] - 8s 162ms/step - loss: 0.2842 - accuracy: 0.8920 - val_loss: 1.3978 - val_accuracy: 0.7250
Epoch 27/100
51/51 [==============================] - ETA: 0s - loss: 0.3163 - accuracy: 0.8784
Epoch 27: val_accuracy did not improve from 0.72500
51/51 [==============================] - 8s 163ms/step - loss: 0.3163 - accuracy: 0.8784 - val_loss: 1.4428 - val_accuracy: 0.7250
Epoch 28/100
51/51 [==============================] - ETA: 0s - loss: 0.2910 - accuracy: 0.8883
Epoch 28: val_accuracy did not improve from 0.72500
51/51 [==============================] - 9s 182ms/step - loss: 0.2910 - accuracy: 0.8883 - val_loss: 1.4257 - val_accuracy: 0.7111
Epoch 29/100
51/51 [==============================] - ETA: 0s - loss: 0.2721 - accuracy: 0.9025
Epoch 29: val_accuracy did not improve from 0.72500
51/51 [==============================] - 8s 161ms/step - loss: 0.2721 - accuracy: 0.9025 - val_loss: 1.4045 - val_accuracy: 0.7139
Epoch 30/100
51/51 [==============================] - ETA: 0s - loss: 0.2741 - accuracy: 0.8957
Epoch 30: val_accuracy improved from 0.72500 to 0.74167, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 9s 181ms/step - loss: 0.2741 - accuracy: 0.8957 - val_loss: 1.4440 - val_accuracy: 0.7417
Epoch 31/100
51/51 [==============================] - ETA: 0s - loss: 0.2596 - accuracy: 0.9123
Epoch 31: val_accuracy did not improve from 0.74167
51/51 [==============================] - 8s 160ms/step - loss: 0.2596 - accuracy: 0.9123 - val_loss: 1.4244 - val_accuracy: 0.7306
Epoch 32/100
51/51 [==============================] - ETA: 0s - loss: 0.2355 - accuracy: 0.9228
Epoch 32: val_accuracy did not improve from 0.74167
51/51 [==============================] - 8s 161ms/step - loss: 0.2355 - accuracy: 0.9228 - val_loss: 1.3860 - val_accuracy: 0.7417
Epoch 33/100
51/51 [==============================] - ETA: 0s - loss: 0.2562 - accuracy: 0.8963
Epoch 33: val_accuracy did not improve from 0.74167
51/51 [==============================] - 8s 162ms/step - loss: 0.2562 - accuracy: 0.8963 - val_loss: 1.4298 - val_accuracy: 0.7417
Epoch 34/100
51/51 [==============================] - ETA: 0s - loss: 0.2079 - accuracy: 0.9210
Epoch 34: val_accuracy did not improve from 0.74167
51/51 [==============================] - 8s 162ms/step - loss: 0.2079 - accuracy: 0.9210 - val_loss: 1.4304 - val_accuracy: 0.7417
Epoch 35/100
51/51 [==============================] - ETA: 0s - loss: 0.1996 - accuracy: 0.9191
Epoch 35: val_accuracy did not improve from 0.74167
51/51 [==============================] - 8s 165ms/step - loss: 0.1996 - accuracy: 0.9191 - val_loss: 1.4654 - val_accuracy: 0.7389
Epoch 36/100
51/51 [==============================] - ETA: 0s - loss: 0.2363 - accuracy: 0.9099
Epoch 36: val_accuracy improved from 0.74167 to 0.74722, saving model to /content/drive/MyDrive/app/T6/best_model2.h5
51/51 [==============================] - 10s 202ms/step - loss: 0.2363 - accuracy: 0.9099 - val_loss: 1.4391 - val_accuracy: 0.7472
Epoch 37/100
51/51 [==============================] - ETA: 0s - loss: 0.2303 - accuracy: 0.9160
Epoch 37: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 162ms/step - loss: 0.2303 - accuracy: 0.9160 - val_loss: 1.4713 - val_accuracy: 0.7444
Epoch 38/100
51/51 [==============================] - ETA: 0s - loss: 0.2382 - accuracy: 0.9068
Epoch 38: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 162ms/step - loss: 0.2382 - accuracy: 0.9068 - val_loss: 1.4830 - val_accuracy: 0.7417
Epoch 39/100
51/51 [==============================] - ETA: 0s - loss: 0.1997 - accuracy: 0.9235
Epoch 39: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 162ms/step - loss: 0.1997 - accuracy: 0.9235 - val_loss: 1.4800 - val_accuracy: 0.7278
Epoch 40/100
51/51 [==============================] - ETA: 0s - loss: 0.1872 - accuracy: 0.9278
Epoch 40: val_accuracy did not improve from 0.74722
51/51 [==============================] - 9s 182ms/step - loss: 0.1872 - accuracy: 0.9278 - val_loss: 1.4292 - val_accuracy: 0.7361
Epoch 41/100
51/51 [==============================] - ETA: 0s - loss: 0.1961 - accuracy: 0.9235
Epoch 41: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 161ms/step - loss: 0.1961 - accuracy: 0.9235 - val_loss: 1.4463 - val_accuracy: 0.7444
Epoch 42/100
51/51 [==============================] - ETA: 0s - loss: 0.2036 - accuracy: 0.9284
Epoch 42: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 162ms/step - loss: 0.2036 - accuracy: 0.9284 - val_loss: 1.4586 - val_accuracy: 0.7361
Epoch 43/100
51/51 [==============================] - ETA: 0s - loss: 0.1565 - accuracy: 0.9377
Epoch 43: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 162ms/step - loss: 0.1565 - accuracy: 0.9377 - val_loss: 1.4763 - val_accuracy: 0.7417
Epoch 44/100
51/51 [==============================] - ETA: 0s - loss: 0.1738 - accuracy: 0.9364
Epoch 44: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 162ms/step - loss: 0.1738 - accuracy: 0.9364 - val_loss: 1.4917 - val_accuracy: 0.7417
Epoch 45/100
51/51 [==============================] - ETA: 0s - loss: 0.1853 - accuracy: 0.9290
Epoch 45: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 162ms/step - loss: 0.1853 - accuracy: 0.9290 - val_loss: 1.4507 - val_accuracy: 0.7333
Epoch 46/100
51/51 [==============================] - ETA: 0s - loss: 0.1699 - accuracy: 0.9383
Epoch 46: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 162ms/step - loss: 0.1699 - accuracy: 0.9383 - val_loss: 1.4672 - val_accuracy: 0.7361
Epoch 47/100
51/51 [==============================] - ETA: 0s - loss: 0.1705 - accuracy: 0.9327
Epoch 47: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 161ms/step - loss: 0.1705 - accuracy: 0.9327 - val_loss: 1.4683 - val_accuracy: 0.7333
Epoch 48/100
51/51 [==============================] - ETA: 0s - loss: 0.1315 - accuracy: 0.9531
Epoch 48: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 162ms/step - loss: 0.1315 - accuracy: 0.9531 - val_loss: 1.4708 - val_accuracy: 0.7333
Epoch 49/100
51/51 [==============================] - ETA: 0s - loss: 0.1561 - accuracy: 0.9352
Epoch 49: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 161ms/step - loss: 0.1561 - accuracy: 0.9352 - val_loss: 1.4800 - val_accuracy: 0.7306
Epoch 50/100
51/51 [==============================] - ETA: 0s - loss: 0.1485 - accuracy: 0.9420
Epoch 50: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 161ms/step - loss: 0.1485 - accuracy: 0.9420 - val_loss: 1.5146 - val_accuracy: 0.7194
Epoch 51/100
51/51 [==============================] - ETA: 0s - loss: 0.1666 - accuracy: 0.9358
Epoch 51: val_accuracy did not improve from 0.74722
51/51 [==============================] - 9s 183ms/step - loss: 0.1666 - accuracy: 0.9358 - val_loss: 1.4996 - val_accuracy: 0.7389
Epoch 52/100
51/51 [==============================] - ETA: 0s - loss: 0.1649 - accuracy: 0.9333
Epoch 52: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 161ms/step - loss: 0.1649 - accuracy: 0.9333 - val_loss: 1.4963 - val_accuracy: 0.7361
Epoch 53/100
51/51 [==============================] - ETA: 0s - loss: 0.1534 - accuracy: 0.9432
Epoch 53: val_accuracy did not improve from 0.74722
51/51 [==============================] - 9s 182ms/step - loss: 0.1534 - accuracy: 0.9432 - val_loss: 1.4637 - val_accuracy: 0.7472
Epoch 54/100
51/51 [==============================] - ETA: 0s - loss: 0.1498 - accuracy: 0.9457
Epoch 54: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 161ms/step - loss: 0.1498 - accuracy: 0.9457 - val_loss: 1.4639 - val_accuracy: 0.7472
Epoch 55/100
51/51 [==============================] - ETA: 0s - loss: 0.1524 - accuracy: 0.9438
Epoch 55: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 162ms/step - loss: 0.1524 - accuracy: 0.9438 - val_loss: 1.4952 - val_accuracy: 0.7361
Epoch 56/100
51/51 [==============================] - ETA: 0s - loss: 0.1234 - accuracy: 0.9568
Epoch 56: val_accuracy did not improve from 0.74722
51/51 [==============================] - 8s 162ms/step - loss: 0.1234 - accuracy: 0.9568 - val_loss: 1.5060 - val_accuracy: 0.7389
Epoch 56: early stopping
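The log shows two callbacks at work. A checkpoint is saved only when `val_accuracy` strictly improves: epochs 53 and 54 equal the best value 0.7472 but still print "did not improve". Early stopping fires at epoch 56, 20 epochs after the last improvement at epoch 36. The callback configuration is not shown in this section, so the pure-Python sketch below only mimics that bookkeeping under stated assumptions: a "max" metric, no `min_delta`, and `patience=20`, which is inferred from the gap between epochs 36 and 56. `VAL_ACC` is copied from epochs 8 to 56 of the log above. `track_best` is a hypothetical helper, not a Keras API.

```python
# val_accuracy for epochs 8..56, copied from the training log above
VAL_ACC = [
    0.6583, 0.6500, 0.6722, 0.6833, 0.6806, 0.6778, 0.6917, 0.6778, 0.6889,  # 8-16
    0.6889, 0.7000, 0.7056, 0.7083, 0.7056, 0.7083, 0.7250, 0.7222, 0.7222,  # 17-25
    0.7250, 0.7250, 0.7111, 0.7139, 0.7417, 0.7306, 0.7417, 0.7417, 0.7417,  # 26-34
    0.7389, 0.7472, 0.7444, 0.7417, 0.7278, 0.7361, 0.7444, 0.7361, 0.7417,  # 35-43
    0.7417, 0.7333, 0.7361, 0.7333, 0.7333, 0.7306, 0.7194, 0.7389, 0.7361,  # 44-52
    0.7472, 0.7472, 0.7361, 0.7389,                                          # 53-56
]


def track_best(val_accs, first_epoch=1, patience=20):
    """Mimic "save best only" checkpointing plus early stopping on a max metric.

    Returns (best_epoch, best_value, stop_epoch, saved_epochs);
    stop_epoch is None if the patience budget is never exhausted.
    """
    best = float("-inf")
    best_epoch = None
    wait = 0          # epochs since the last strict improvement
    saved = []        # epochs at which a checkpoint would be written
    for epoch, acc in enumerate(val_accs, start=first_epoch):
        if acc > best:  # strict improvement only; ties do not count
            best, best_epoch, wait = acc, epoch, 0
            saved.append(epoch)
        else:
            wait += 1
            if wait >= patience:  # patience exhausted: stop training here
                return best_epoch, best, epoch, saved
    return best_epoch, best, None, saved
```

With these values, `track_best(VAL_ACC, first_epoch=8, patience=20)` reports the best epoch as 36 (0.7472), stops at epoch 56, and "saves" at the same epochs where the log says "saving model" (epoch 8 counts only because the sketch starts there). In Keras, the matching callbacks would plausibly be `tf.keras.callbacks.ModelCheckpoint(filepath, monitor="val_accuracy", save_best_only=True, verbose=1)` and `tf.keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=20, verbose=1)`. These settings are inferred from the log, not copied from the author's code.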

4. Loss and Accuracy Curves

(Figure: training and validation accuracy and loss curves)
The improvement in model performance is plain to see!
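The code that produced this figure is not part of this section. Below is a minimal matplotlib sketch of how such curves are typically drawn. It assumes `history = model.fit(...)`, so that `history.history` holds the keys `accuracy`, `val_accuracy`, `loss` and `val_loss`, which are the column names visible in the training log. `plot_history` is a hypothetical helper.

```python
import matplotlib.pyplot as plt


def plot_history(history, figsize=(12, 4)):
    """Plot training/validation accuracy and loss side by side.

    `history` is assumed to be the `.history` dict returned by `model.fit`.
    """
    epochs = range(1, len(history["loss"]) + 1)  # 1-based, like the log
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=figsize)

    # Left panel: accuracy curves
    ax_acc.plot(epochs, history["accuracy"], label="Training Accuracy")
    ax_acc.plot(epochs, history["val_accuracy"], label="Validation Accuracy")
    ax_acc.set_title("Training and Validation Accuracy")
    ax_acc.set_xlabel("Epoch")
    ax_acc.legend(loc="lower right")

    # Right panel: loss curves
    ax_loss.plot(epochs, history["loss"], label="Training Loss")
    ax_loss.plot(epochs, history["val_loss"], label="Validation Loss")
    ax_loss.set_title("Training and Validation Loss")
    ax_loss.set_xlabel("Epoch")
    ax_loss.legend(loc="upper right")

    fig.tight_layout()
    return fig
```

In the notebook, call `plot_history(history.history)` followed by `plt.show()`. With the run above, a widening gap between the falling training loss and the flat or rising validation loss (around 1.3 to 1.5) is what the right panel would show.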

model.evaluate(val_ds)

Output:

12/12 [==============================] - 1s 116ms/step - loss: 1.5060 - accuracy: 0.7389
[1.5059789419174194, 0.7388888597488403]
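These numbers equal the epoch-56 row of the training log (`val_loss: 1.5060 - val_accuracy: 0.7389`), not the best checkpoint (0.7472 at epoch 36). This indicates that `model` still holds the final-epoch weights after early stopping. That is the default unless `EarlyStopping(restore_best_weights=True)` is used or the checkpoint is reloaded. The hypothetical helper below reads the `history.history` dict returned by `model.fit` and makes this gap explicit. It assumes the dict contains a `val_accuracy` key.

```python
def compare_best_and_last(history, metric="val_accuracy"):
    """Compare the best epoch of `metric` with the last epoch (epochs are 1-based).

    `history` is assumed to be the dict stored in `model.fit(...).history`.
    Ties resolve to the earliest epoch, matching a checkpoint that only
    saves on strict improvement.
    """
    values = history[metric]
    best_idx = max(range(len(values)), key=lambda i: values[i])  # first max wins
    return {
        "best_epoch": best_idx + 1,
        "best": values[best_idx],
        "last_epoch": len(values),
        "last": values[-1],
        "gap": values[best_idx] - values[-1],  # accuracy lost by not restoring
    }
```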

5. Predict a Specified Image

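from PIL import Image
import numpy as np

# Load the weights of the best model: the checkpoint file named in the training log above
model.load_weights('/content/drive/MyDrive/app/T6/best_model2.h5')

# Choose the image to predict; convert("RGB") guarantees 3 channels even for grayscale/RGBA files
img = Image.open("/content/drive/MyDrive/app/T6/48-data/Tom Cruise/001_08212dcd.jpg").convert("RGB")
image = tf.image.resize(np.asarray(img), [img_height, img_width])

# Add a batch dimension: (224, 224, 3) -> (1, 224, 224, 3)
img_array = tf.expand_dims(image, 0)

predictions = model.predict(img_array)  # use the model you have already trained
print("Prediction:", class_names[np.argmax(predictions)])

Output:

1/1 [==============================] - 0s 30ms/step
Prediction: Tom Cruise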

The prediction is correct.
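`np.argmax` reports only the single most likely class. When a face is ambiguous, it can help to inspect the next few candidates as well. The hypothetical helper below assumes the model's last layer is a softmax over the classes in `class_names`. This is consistent with `label_mode="categorical"`, but the model definition is not shown in this section.

```python
import numpy as np


def top_k_predictions(probs, class_names, k=3):
    """Return the k most likely (class_name, probability) pairs for one image.

    `probs` is assumed to be softmax output for a single image, with shape
    (num_classes,) or (1, num_classes), e.g. the result of model.predict.
    """
    probs = np.asarray(probs).reshape(-1)   # flatten a (1, C) batch to (C,)
    top = np.argsort(probs)[::-1][:k]       # indices of the k largest scores
    return [(class_names[i], float(probs[i])) for i in top]
```

Usage in the notebook would be `top_k_predictions(predictions, class_names, k=3)`. A large margin between the first and second entries suggests a confident prediction.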

VII. Summary

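Using the official VGG16 model directly gave poor results. Freezing the VGG convolutional base, training only the Dense layers, and adding Dropout improved performance substantially: the best validation accuracy reached 0.7472 (epoch 36) before early stopping ended training at epoch 56.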
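The summary describes transfer learning with a frozen VGG16 base. A minimal Keras sketch of that setup follows. It is not the author's exact model: the 256-unit hidden layer, `dropout_rate=0.5`, the Adam learning rate of 1e-4, and the use of `vgg16.preprocess_input` are assumptions. `num_classes=17` matches the class list printed earlier, and `categorical_crossentropy` matches the one-hot labels from `label_mode="categorical"`.

```python
import tensorflow as tf


def build_frozen_vgg16(num_classes=17, img_height=224, img_width=224,
                       weights="imagenet", dropout_rate=0.5):
    """Frozen VGG16 convolutional base plus a trainable Dense head with Dropout.

    Layer sizes, dropout_rate and learning rate are illustrative assumptions.
    """
    base = tf.keras.applications.VGG16(
        include_top=False,                       # drop VGG16's 1000-class Dense layers
        weights=weights,                         # "imagenet" downloads pretrained weights
        input_shape=(img_height, img_width, 3),
    )
    base.trainable = False                       # freeze the convolutional base

    inputs = tf.keras.Input(shape=(img_height, img_width, 3))
    # VGG16's ImageNet weights expect its own preprocessing of 0-255 RGB inputs
    x = tf.keras.applications.vgg16.preprocess_input(inputs)
    x = base(x, training=False)                  # run the frozen base in inference mode
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(256, activation="relu")(x)
    x = tf.keras.layers.Dropout(dropout_rate)(x)  # regularize the trainable head
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)

    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
                  loss="categorical_crossentropy",  # one-hot labels
                  metrics=["accuracy"])
    return model
```

Only the two Dense layers carry trainable weights here, which is what keeps training fast and limits overfitting on roughly 1,800 images. Unfreezing the last convolutional block later with a smaller learning rate is a common follow-up, but it is not something this article tried.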
