Using TensorFlow 2 with the cats-and-dogs dataset (cats_and_dogs_filtered.zip) to build a convolutional neural network for classification

This walkthrough uses the TensorFlow 2 high-level APIs to classify cats and dogs from the cats_and_dogs_filtered dataset. The main topics and techniques covered are:

  • Data preprocessing with tf.keras.preprocessing.image.ImageDataGenerator
  • Preventing overfitting
  • Data augmentation and Dropout

1. Import the required libraries

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

for i in [tf, np]:
    print(i.__name__,": ",i.__version__,sep="")

Output:

tensorflow: 2.2.0
numpy: 1.17.4

2. Download and load the dataset

dataset_url = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"

dataset_path = tf.keras.utils.get_file("cats_and_dogs_filtered.zip",origin=dataset_url,extract=True)
dataset_dir = os.path.join(os.path.dirname(dataset_path),"cats_and_dogs_filtered")

print(dataset_path)
print(dataset_dir)

Output:

C:\Users\my-pc\.keras\datasets\cats_and_dogs_filtered.zip
C:\Users\my-pc\.keras\datasets\cats_and_dogs_filtered

Dataset directory structure:
cats_and_dogs_filtered
|__ train
    |______ cats: [cat.0.jpg, cat.1.jpg, cat.2.jpg ....]
    |______ dogs: [dog.0.jpg, dog.1.jpg, dog.2.jpg ...]
|__ validation
    |______ cats: [cat.2000.jpg, cat.2001.jpg, cat.2002.jpg ....]
    |______ dogs: [dog.2000.jpg, dog.2001.jpg, dog.2002.jpg ...]
|__ vectorize.py

train_cats = os.path.join(dataset_dir,"train","cats")
train_dogs = os.path.join(dataset_dir,"train","dogs")
validation_cats = os.path.join(dataset_dir,"validation","cats")
validation_dogs = os.path.join(dataset_dir,"validation","dogs")

train_dir = os.path.join(dataset_dir,"train")
validation_dir = os.path.join(dataset_dir,"validation")

for i in [train_dir,validation_dir, train_cats, train_dogs, validation_cats, validation_dogs]:
    print(i)

Output:

C:\Users\my-pc\.keras\datasets\cats_and_dogs_filtered\train
C:\Users\my-pc\.keras\datasets\cats_and_dogs_filtered\validation
C:\Users\my-pc\.keras\datasets\cats_and_dogs_filtered\train\cats
C:\Users\my-pc\.keras\datasets\cats_and_dogs_filtered\train\dogs
C:\Users\my-pc\.keras\datasets\cats_and_dogs_filtered\validation\cats
C:\Users\my-pc\.keras\datasets\cats_and_dogs_filtered\validation\dogs

3. Exploring the data

3.1 Count the training and validation set sizes

train_cats_num = len(os.listdir(train_cats))
train_dogs_num = len(os.listdir(train_dogs))

validation_cats_num = len(os.listdir(validation_cats))
validation_dogs_num = len(os.listdir(validation_dogs))

train_all = train_cats_num + train_dogs_num
validation_all = validation_cats_num + validation_dogs_num

print("train cats: ",train_cats_num)
print("train dogs: ",train_dogs_num)
print("validation cats: ",validation_cats_num)
print("validation_dogs: ",validation_dogs_num)
print("all train images: ",train_all)
print("all validation images: ",validation_all)

Output:

train cats:  1000
train dogs:  1000
validation cats:  500
validation_dogs:  500
all train images:  2000
all validation images:  1000

3.2 Set hyperparameters

batch_size = 128
epochs = 50
height = 150
width = 150

4. Data preprocessing

Preprocessing the image data involves the following steps:

  • read the image files from disk
  • decode the image contents and convert them to a suitable format
  • convert them to floating-point tensors
  • rescale the pixel values from the 0-255 range to values between 0 and 1

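All of these steps are handled by ImageDataGenerator below. As a standalone illustration of the last step, here is the rescale transform applied to a mock pixel array (illustrative values, not an image from the dataset):

```python
import numpy as np

# A mock 1x1 RGB "image" with raw uint8 pixel values (illustration only).
pixels = np.array([[[0, 128, 255]]], dtype=np.uint8)

# The same transform that rescale=1.0/255 applies: map 0-255 into [0, 1].
scaled = pixels / 255.0

print(scaled)  # values now lie in [0.0, 1.0]
```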
train_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255)
validation_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255)

train_data_gen = train_generator.flow_from_directory(batch_size=batch_size,
                                                    directory=train_dir,
                                                    shuffle=True,
                                                    target_size=(height,width),
                                                    class_mode="binary")

Output:

Found 2000 images belonging to 2 classes.

val_data_gen = validation_generator.flow_from_directory(batch_size=batch_size,
                                                       directory=validation_dir,
                                                       target_size=(height,width),
                                                       class_mode="binary")

Output:

Found 1000 images belonging to 2 classes.

5. Visualizing sample images

sample_training_images, _ = next(train_data_gen)

def plotImages(images_arr):
    fig, axes = plt.subplots(1,5,figsize=(20,20))
    axes = axes.flatten()
    for img, ax in zip(images_arr, axes):
        ax.imshow(img)
        ax.axis("off")
    plt.tight_layout()
    plt.show()

plotImages(sample_training_images[:5])

Output: (plot of five sample training images)

6. Building the model

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16,3,padding="same",activation="relu",input_shape=(height,width,3)),
    tf.keras.layers.MaxPooling2D(),  # default pool size is 2x2
    tf.keras.layers.Conv2D(32,3,padding="same",activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64,3,padding="same",activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512,activation="relu"),
    tf.keras.layers.Dense(1)    
])
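Note that the final Dense(1) layer has no activation: it outputs a raw logit rather than a probability, and from_logits=True in the compile step below tells BinaryCrossentropy to apply the sigmoid internally. A minimal sketch of that logit-to-probability conversion (flow_from_directory assigns class indices alphabetically, so cats = 0 and dogs = 1, and sigmoid(logit) is the predicted probability of "dog"):

```python
import math

def sigmoid(logit):
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))

print(sigmoid(0.0))   # 0.5: the model is undecided
print(sigmoid(2.0))   # ~0.88: leaning towards class 1 (dog)
print(sigmoid(-2.0))  # ~0.12: leaning towards class 0 (cat)
```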

model.compile(optimizer="adam",
             loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
             metrics=["accuracy"])

model.summary()

Output:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 150, 150, 16)      448       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 75, 75, 16)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 75, 75, 32)        4640      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 37, 37, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 37, 37, 64)        18496     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 18, 18, 64)        0         
_________________________________________________________________
flatten (Flatten)            (None, 20736)             0         
_________________________________________________________________
dense (Dense)                (None, 512)               10617344  
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 513       
=================================================================
Total params: 10,641,441
Trainable params: 10,641,441
Non-trainable params: 0
_________________________________________________________________
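The parameter counts in the summary can be checked by hand: a Conv2D layer has (k·k·in_channels + 1)·filters parameters (the +1 is the per-filter bias), and a Dense layer has (inputs + 1)·units. Each MaxPooling2D halves the spatial size with floor division (150 → 75 → 37 → 18), so Flatten outputs 18·18·64 = 20736 features:

```python
def conv2d_params(k, in_ch, filters):
    # k*k*in_ch weights per filter, plus one bias per filter
    return (k * k * in_ch + 1) * filters

def dense_params(inputs, units):
    # one weight per input per unit, plus one bias per unit
    return (inputs + 1) * units

print(conv2d_params(3, 3, 16))           # 448
print(conv2d_params(3, 16, 32))          # 4640
print(conv2d_params(3, 32, 64))          # 18496
print(dense_params(18 * 18 * 64, 512))   # 10617344
print(dense_params(512, 1))              # 513

total = (conv2d_params(3, 3, 16) + conv2d_params(3, 16, 32)
         + conv2d_params(3, 32, 64)
         + dense_params(18 * 18 * 64, 512) + dense_params(512, 1))
print(total)  # 10641441, matching "Total params" above
```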

7. Training the model

# Note: fit_generator is deprecated in TF 2.x; model.fit accepts generators directly.
history = model.fit_generator(train_data_gen,
                             steps_per_epoch=train_all//batch_size,
                             epochs=epochs,
                             validation_data=val_data_gen,
                             validation_steps=validation_all//batch_size)
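The steps_per_epoch and validation_steps values are simply the dataset sizes floor-divided by the batch size, which is why the log below shows 15 batches per epoch. Floor division drops the leftover partial batch (2000 − 15·128 = 80 images) from the step count; since the generator cycles, those images still get used across epochs:

```python
batch_size = 128
train_all = 2000        # total training images
validation_all = 1000   # total validation images

print(train_all // batch_size)       # 15 steps per epoch
print(validation_all // batch_size)  # 7 validation steps
print(train_all - (train_all // batch_size) * batch_size)  # 80 leftover images per pass
```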

Output:

Epoch 1/50
15/15 [==============================] - 4s 263ms/step - loss: 0.7677 - accuracy: 0.5000 - val_loss: 0.6865 - val_accuracy: 0.4955
Epoch 2/50
15/15 [==============================] - 4s 255ms/step - loss: 0.6777 - accuracy: 0.5339 - val_loss: 0.6408 - val_accuracy: 0.5804
Epoch 3/50
15/15 [==============================] - 4s 279ms/step - loss: 0.6578 - accuracy: 0.5550 - val_loss: 0.6475 - val_accuracy: 0.5859
Epoch 4/50
15/15 [==============================] - 4s 250ms/step - loss: 0.6156 - accuracy: 0.6389 - val_loss: 0.6504 - val_accuracy: 0.6886
Epoch 5/50
15/15 [==============================] - 4s 234ms/step - loss: 0.5880 - accuracy: 0.6757 - val_loss: 0.5863 - val_accuracy: 0.6339
Epoch 6/50
15/15 [==============================] - 4s 242ms/step - loss: 0.5080 - accuracy: 0.7329 - val_loss: 0.6115 - val_accuracy: 0.7154
Epoch 7/50
15/15 [==============================] - 4s 249ms/step - loss: 0.4502 - accuracy: 0.7724 - val_loss: 0.5791 - val_accuracy: 0.7243
Epoch 8/50
15/15 [==============================] - 3s 228ms/step - loss: 0.4127 - accuracy: 0.8040 - val_loss: 0.5670 - val_accuracy: 0.6975
Epoch 9/50
15/15 [==============================] - 3s 228ms/step - loss: 0.3744 - accuracy: 0.8173 - val_loss: 0.5866 - val_accuracy: 0.6797
Epoch 10/50
15/15 [==============================] - 3s 230ms/step - loss: 0.3185 - accuracy: 0.8365 - val_loss: 0.6255 - val_accuracy: 0.7121
Epoch 11/50
15/15 [==============================] - 3s 225ms/step - loss: 0.2731 - accuracy: 0.8750 - val_loss: 0.6398 - val_accuracy: 0.7132
Epoch 12/50
15/15 [==============================] - 3s 222ms/step - loss: 0.2391 - accuracy: 0.8926 - val_loss: 0.6647 - val_accuracy: 0.7031
Epoch 13/50
15/15 [==============================] - 3s 222ms/step - loss: 0.1790 - accuracy: 0.9257 - val_loss: 0.7181 - val_accuracy: 0.7188
Epoch 14/50
15/15 [==============================] - 3s 233ms/step - loss: 0.1584 - accuracy: 0.9386 - val_loss: 0.8213 - val_accuracy: 0.6942
Epoch 15/50
15/15 [==============================] - 3s 232ms/step - loss: 0.1195 - accuracy: 0.9605 - val_loss: 1.0006 - val_accuracy: 0.6808
Epoch 16/50
15/15 [==============================] - 3s 226ms/step - loss: 0.0907 - accuracy: 0.9679 - val_loss: 0.9743 - val_accuracy: 0.7031
Epoch 17/50
15/15 [==============================] - 4s 260ms/step - loss: 0.0700 - accuracy: 0.9802 - val_loss: 1.0655 - val_accuracy: 0.6830
Epoch 18/50
15/15 [==============================] - 4s 259ms/step - loss: 0.0515 - accuracy: 0.9870 - val_loss: 1.1483 - val_accuracy: 0.6942
Epoch 19/50
15/15 [==============================] - 4s 237ms/step - loss: 0.0373 - accuracy: 0.9931 - val_loss: 1.2806 - val_accuracy: 0.7020
Epoch 20/50
15/15 [==============================] - 3s 230ms/step - loss: 0.0324 - accuracy: 0.9915 - val_loss: 1.2004 - val_accuracy: 0.7154
Epoch 21/50
15/15 [==============================] - 4s 259ms/step - loss: 0.0224 - accuracy: 0.9952 - val_loss: 1.3527 - val_accuracy: 0.6964
Epoch 22/50
15/15 [==============================] - 4s 257ms/step - loss: 0.0229 - accuracy: 0.9952 - val_loss: 1.3227 - val_accuracy: 0.7031
Epoch 23/50
15/15 [==============================] - 4s 235ms/step - loss: 0.0148 - accuracy: 0.9989 - val_loss: 1.3873 - val_accuracy: 0.6975
Epoch 24/50
15/15 [==============================] - 4s 238ms/step - loss: 0.0116 - accuracy: 0.9979 - val_loss: 1.4497 - val_accuracy: 0.7277
Epoch 25/50
15/15 [==============================] - 3s 225ms/step - loss: 0.0202 - accuracy: 0.9941 - val_loss: 1.4114 - val_accuracy: 0.7210
Epoch 26/50
15/15 [==============================] - 3s 228ms/step - loss: 0.0261 - accuracy: 0.9931 - val_loss: 1.4986 - val_accuracy: 0.7020
Epoch 27/50
15/15 [==============================] - 3s 228ms/step - loss: 0.0129 - accuracy: 0.9979 - val_loss: 1.6209 - val_accuracy: 0.7054
Epoch 28/50
15/15 [==============================] - 3s 226ms/step - loss: 0.0087 - accuracy: 0.9984 - val_loss: 1.7243 - val_accuracy: 0.7132
Epoch 29/50
15/15 [==============================] - 3s 220ms/step - loss: 0.0046 - accuracy: 1.0000 - val_loss: 1.8122 - val_accuracy: 0.6942
Epoch 30/50
15/15 [==============================] - 3s 224ms/step - loss: 0.0023 - accuracy: 1.0000 - val_loss: 1.7980 - val_accuracy: 0.7143
Epoch 31/50
15/15 [==============================] - 3s 225ms/step - loss: 0.0017 - accuracy: 1.0000 - val_loss: 1.7874 - val_accuracy: 0.7221
Epoch 32/50
15/15 [==============================] - 3s 220ms/step - loss: 0.0014 - accuracy: 1.0000 - val_loss: 1.8793 - val_accuracy: 0.7143
Epoch 33/50
15/15 [==============================] - 3s 217ms/step - loss: 0.0011 - accuracy: 1.0000 - val_loss: 1.8711 - val_accuracy: 0.7143
Epoch 34/50
15/15 [==============================] - 3s 224ms/step - loss: 0.0010 - accuracy: 1.0000 - val_loss: 1.9720 - val_accuracy: 0.7132
Epoch 35/50
15/15 [==============================] - 3s 226ms/step - loss: 8.4621e-04 - accuracy: 1.0000 - val_loss: 2.0077 - val_accuracy: 0.6998
Epoch 36/50
15/15 [==============================] - 3s 223ms/step - loss: 7.9129e-04 - accuracy: 1.0000 - val_loss: 2.0034 - val_accuracy: 0.7009
Epoch 37/50
15/15 [==============================] - 3s 223ms/step - loss: 8.0383e-04 - accuracy: 1.0000 - val_loss: 2.0318 - val_accuracy: 0.7087
Epoch 38/50
15/15 [==============================] - 3s 221ms/step - loss: 6.2408e-04 - accuracy: 1.0000 - val_loss: 2.0375 - val_accuracy: 0.7132
Epoch 39/50
15/15 [==============================] - 3s 223ms/step - loss: 6.5080e-04 - accuracy: 1.0000 - val_loss: 2.0763 - val_accuracy: 0.7132
Epoch 40/50
15/15 [==============================] - 3s 229ms/step - loss: 4.8377e-04 - accuracy: 1.0000 - val_loss: 2.0136 - val_accuracy: 0.7121
Epoch 41/50
15/15 [==============================] - 3s 230ms/step - loss: 5.8370e-04 - accuracy: 1.0000 - val_loss: 2.1221 - val_accuracy: 0.7121
Epoch 42/50
15/15 [==============================] - 3s 220ms/step - loss: 4.9391e-04 - accuracy: 1.0000 - val_loss: 2.0811 - val_accuracy: 0.7042
Epoch 43/50
15/15 [==============================] - 3s 221ms/step - loss: 4.5191e-04 - accuracy: 1.0000 - val_loss: 2.0469 - val_accuracy: 0.7076
Epoch 44/50
15/15 [==============================] - 3s 225ms/step - loss: 4.6518e-04 - accuracy: 1.0000 - val_loss: 2.1188 - val_accuracy: 0.7087
Epoch 45/50
15/15 [==============================] - 3s 227ms/step - loss: 4.3731e-04 - accuracy: 1.0000 - val_loss: 2.0347 - val_accuracy: 0.7154
Epoch 46/50
15/15 [==============================] - 3s 231ms/step - loss: 3.5848e-04 - accuracy: 1.0000 - val_loss: 2.0776 - val_accuracy: 0.7087
Epoch 47/50
15/15 [==============================] - 3s 221ms/step - loss: 3.6341e-04 - accuracy: 1.0000 - val_loss: 2.2015 - val_accuracy: 0.7087
Epoch 48/50
15/15 [==============================] - 3s 221ms/step - loss: 3.5506e-04 - accuracy: 1.0000 - val_loss: 2.1296 - val_accuracy: 0.7188
Epoch 49/50
15/15 [==============================] - 3s 221ms/step - loss: 3.1891e-04 - accuracy: 1.0000 - val_loss: 2.1128 - val_accuracy: 0.7098
Epoch 50/50
15/15 [==============================] - 3s 225ms/step - loss: 3.1507e-04 - accuracy: 1.0000 - val_loss: 2.2647 - val_accuracy: 0.7020

8. Visualizing the training results

accuracy = history.history["accuracy"]
val_accuracy = history.history["val_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]
    
epochs_range=range(epochs)

plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.plot(epochs_range, accuracy, label="Training Accuracy")
plt.plot(epochs_range, val_accuracy,label="Validation Accuracy")
plt.legend()
plt.title("Training and Validation Accuracy")

plt.subplot(1,2,2)
plt.plot(epochs_range, loss, label="Training Loss")
plt.plot(epochs_range, val_loss,label="Validation Loss")
plt.legend()
plt.title("Training and Validation Loss")
plt.show()

Output: (plots of training/validation accuracy and loss)

As the plots show, the model reaches 100% accuracy on the training set while accuracy on the validation set plateaus around 70%: the model is overfitting. To address this, we augment the training data and add Dropout layers to the model.
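Plugging in the final-epoch numbers from the log above makes the gap concrete:

```python
# Final-epoch metrics copied from the training log above.
final = {"accuracy": 1.0000, "val_accuracy": 0.7020,
         "loss": 3.1507e-04, "val_loss": 2.2647}

acc_gap = final["accuracy"] - final["val_accuracy"]
print(f"accuracy gap: {acc_gap:.4f}")  # 0.2980: training accuracy is ~30 points above validation
print(f"val loss is {final['val_loss'] / final['loss']:.0f}x the training loss")
```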

9. Data augmentation

9.1 Horizontal flip

image_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,horizontal_flip=True)

train_data_gen = image_gen.flow_from_directory(batch_size=batch_size,
                                              directory=train_dir,
                                              shuffle=True,
                                              target_size=(height,width))

Output:

Found 2000 images belonging to 2 classes.

# Indexing the generator re-draws batch 0 each time, so this collects
# five independent random augmentations of the same source image.
augmented_images = [train_data_gen[0][0][0] for i in range(5)]

plotImages(augmented_images)

Output: (plot of five randomly flipped variants of one image)

9.2 Random rotation

image_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,rotation_range=45)

train_data_gen = image_gen.flow_from_directory(batch_size=batch_size,
                                              directory=train_dir,
                                              shuffle=True,
                                              target_size=(height,width))

augmented_images = [train_data_gen[0][0][0] for i in range(5)]
plotImages(augmented_images)

Output:

Found 2000 images belonging to 2 classes.
(plot of five randomly rotated variants of one image)

9.3 Random zoom

image_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,zoom_range=0.5)

train_data_gen = image_gen.flow_from_directory(batch_size=batch_size,
                                              directory=train_dir,
                                              shuffle=True,
                                              target_size=(height,width))

augmented_images = [train_data_gen[0][0][0] for i in range(5)]
plotImages(augmented_images)

Output:

Found 2000 images belonging to 2 classes.
(plot of five randomly zoomed variants of one image)

9.4 Combining several augmentations

image_gen_train = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,
                                                                 rotation_range=45,
                                                                 width_shift_range=.15,
                                                                 height_shift_range=.15,
                                                                 horizontal_flip=True,
                                                                 zoom_range=0.5)

train_data_gen = image_gen_train.flow_from_directory(batch_size=batch_size,
                                                    directory=train_dir,
                                                    shuffle=True,
                                                    target_size=(height,width),
                                                    class_mode="binary")
augmented_images = [train_data_gen[0][0][0] for i in range(5)]
plotImages(augmented_images)

Output:

Found 2000 images belonging to 2 classes.
(plot of five variants with all augmentations applied)

image_gen_val = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
val_data_gen = image_gen_val.flow_from_directory(batch_size=batch_size,
                                                directory=validation_dir,
                                                target_size=(height,width),
                                                class_mode="binary")

Output:

Found 1000 images belonging to 2 classes.

10. Adding Dropout to the model

model_new = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16,3,padding="same",activation="relu",input_shape=(height,width,3)),
    tf.keras.layers.MaxPooling2D(),  # default pool size is 2x2
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Conv2D(32,3,padding="same",activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64,3,padding="same",activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512,activation="relu"),
    tf.keras.layers.Dense(1)    
])

model_new.compile(optimizer="adam",
             loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
             metrics=["accuracy"])

model_new.summary()

Output:

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_6 (Conv2D)            (None, 150, 150, 16)      448       
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 75, 75, 16)        0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 75, 75, 16)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 75, 75, 32)        4640      
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 37, 37, 32)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 37, 37, 64)        18496     
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 18, 18, 64)        0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 18, 18, 64)        0         
_________________________________________________________________
flatten_2 (Flatten)          (None, 20736)             0         
_________________________________________________________________
dense_4 (Dense)              (None, 512)               10617344  
_________________________________________________________________
dense_5 (Dense)              (None, 1)                 513       
=================================================================
Total params: 10,641,441
Trainable params: 10,641,441
Non-trainable params: 0
_________________________________________________________________
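The Dropout(0.2) layers above randomly zero 20% of their inputs during training and do nothing at inference time. Keras uses "inverted" dropout: the surviving activations are scaled by 1/(1 − rate) so the expected activation is unchanged. A minimal NumPy sketch of this behavior (an illustration, not the actual Keras implementation):

```python
import numpy as np

def inverted_dropout(x, rate, rng):
    """Zero out roughly `rate` of the entries; scale survivors by 1/(1-rate)."""
    keep_mask = rng.random(x.shape) >= rate  # True where the entry survives
    return x * keep_mask / (1.0 - rate)

rng = np.random.default_rng(seed=0)
x = np.ones((4, 4))
y = inverted_dropout(x, rate=0.2, rng=rng)

print(y)         # entries are either 0.0 (dropped) or 1.25 (= 1/0.8, kept)
print(y.mean())  # close to 1.0 on average: the expected activation is preserved
```

At inference time (model.predict / model.evaluate) Keras simply bypasses the layer, which is why no rescaling is needed there.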

11. Training the new model

history_new = model_new.fit_generator(train_data_gen,
                                     steps_per_epoch=train_all//batch_size,
                                     epochs=epochs,
                                     validation_data=val_data_gen,
                                     validation_steps=validation_all//batch_size)

Output:

Epoch 1/50
15/15 [==============================] - 10s 684ms/step - loss: 0.7637 - accuracy: 0.5141 - val_loss: 0.6931 - val_accuracy: 0.5045
Epoch 2/50
15/15 [==============================] - 10s 641ms/step - loss: 0.6933 - accuracy: 0.4973 - val_loss: 0.6913 - val_accuracy: 0.5000
Epoch 3/50
15/15 [==============================] - 10s 646ms/step - loss: 0.6914 - accuracy: 0.5075 - val_loss: 0.6904 - val_accuracy: 0.4989
Epoch 4/50
15/15 [==============================] - 10s 637ms/step - loss: 0.6866 - accuracy: 0.4952 - val_loss: 0.6821 - val_accuracy: 0.5045
Epoch 5/50
15/15 [==============================] - 10s 638ms/step - loss: 0.6734 - accuracy: 0.5278 - val_loss: 0.6832 - val_accuracy: 0.4989
Epoch 6/50
15/15 [==============================] - 9s 624ms/step - loss: 0.6605 - accuracy: 0.5759 - val_loss: 0.6376 - val_accuracy: 0.5748
Epoch 7/50
15/15 [==============================] - 9s 598ms/step - loss: 0.6420 - accuracy: 0.5919 - val_loss: 0.6175 - val_accuracy: 0.6507
Epoch 8/50
15/15 [==============================] - 9s 597ms/step - loss: 0.6252 - accuracy: 0.6154 - val_loss: 0.6036 - val_accuracy: 0.6440
Epoch 9/50
15/15 [==============================] - 9s 617ms/step - loss: 0.6100 - accuracy: 0.6223 - val_loss: 0.5864 - val_accuracy: 0.6272
Epoch 10/50
15/15 [==============================] - 9s 627ms/step - loss: 0.6159 - accuracy: 0.6368 - val_loss: 0.5976 - val_accuracy: 0.6574
Epoch 11/50
15/15 [==============================] - 10s 634ms/step - loss: 0.6031 - accuracy: 0.6314 - val_loss: 0.5692 - val_accuracy: 0.6585
Epoch 12/50
15/15 [==============================] - 10s 656ms/step - loss: 0.5707 - accuracy: 0.6875 - val_loss: 0.5565 - val_accuracy: 0.6975
Epoch 13/50
15/15 [==============================] - 10s 651ms/step - loss: 0.5981 - accuracy: 0.6512 - val_loss: 0.5727 - val_accuracy: 0.6752
Epoch 14/50
15/15 [==============================] - 10s 640ms/step - loss: 0.5715 - accuracy: 0.6768 - val_loss: 0.5363 - val_accuracy: 0.6987
Epoch 15/50
15/15 [==============================] - 9s 601ms/step - loss: 0.5906 - accuracy: 0.6485 - val_loss: 0.5589 - val_accuracy: 0.6540
Epoch 16/50
15/15 [==============================] - 9s 627ms/step - loss: 0.5871 - accuracy: 0.6725 - val_loss: 0.5651 - val_accuracy: 0.6440
Epoch 17/50
15/15 [==============================] - 9s 622ms/step - loss: 0.5724 - accuracy: 0.6667 - val_loss: 0.5524 - val_accuracy: 0.6741
Epoch 18/50
15/15 [==============================] - 9s 609ms/step - loss: 0.5510 - accuracy: 0.6886 - val_loss: 0.5363 - val_accuracy: 0.6819
Epoch 19/50
15/15 [==============================] - 9s 589ms/step - loss: 0.5790 - accuracy: 0.6629 - val_loss: 0.5671 - val_accuracy: 0.6283
Epoch 20/50
15/15 [==============================] - 9s 592ms/step - loss: 0.5704 - accuracy: 0.6704 - val_loss: 0.5361 - val_accuracy: 0.6942
Epoch 21/50
15/15 [==============================] - 9s 591ms/step - loss: 0.5361 - accuracy: 0.7051 - val_loss: 0.5303 - val_accuracy: 0.7132
Epoch 22/50
15/15 [==============================] - 9s 590ms/step - loss: 0.5428 - accuracy: 0.7089 - val_loss: 0.5417 - val_accuracy: 0.6741
Epoch 23/50
15/15 [==============================] - 9s 590ms/step - loss: 0.5527 - accuracy: 0.6891 - val_loss: 0.5289 - val_accuracy: 0.7042
Epoch 24/50
15/15 [==============================] - 9s 594ms/step - loss: 0.5365 - accuracy: 0.7110 - val_loss: 0.5222 - val_accuracy: 0.6853
Epoch 25/50
15/15 [==============================] - 9s 589ms/step - loss: 0.5314 - accuracy: 0.7169 - val_loss: 0.5321 - val_accuracy: 0.6875
Epoch 26/50
15/15 [==============================] - 9s 621ms/step - loss: 0.5469 - accuracy: 0.6768 - val_loss: 0.5188 - val_accuracy: 0.7199
Epoch 27/50
15/15 [==============================] - 9s 612ms/step - loss: 0.5281 - accuracy: 0.7249 - val_loss: 0.5133 - val_accuracy: 0.7333
Epoch 28/50
15/15 [==============================] - 9s 594ms/step - loss: 0.5367 - accuracy: 0.7137 - val_loss: 0.5380 - val_accuracy: 0.7299
Epoch 29/50
15/15 [==============================] - 9s 602ms/step - loss: 0.5470 - accuracy: 0.7115 - val_loss: 0.5354 - val_accuracy: 0.7042
Epoch 30/50
15/15 [==============================] - 9s 608ms/step - loss: 0.5403 - accuracy: 0.7099 - val_loss: 0.5260 - val_accuracy: 0.6897
Epoch 31/50
15/15 [==============================] - 10s 648ms/step - loss: 0.5257 - accuracy: 0.7276 - val_loss: 0.5269 - val_accuracy: 0.6942
Epoch 32/50
15/15 [==============================] - 9s 603ms/step - loss: 0.5275 - accuracy: 0.7035 - val_loss: 0.5461 - val_accuracy: 0.6864
Epoch 33/50
15/15 [==============================] - 9s 605ms/step - loss: 0.5307 - accuracy: 0.7198 - val_loss: 0.5752 - val_accuracy: 0.6696
Epoch 34/50
15/15 [==============================] - 9s 608ms/step - loss: 0.5253 - accuracy: 0.7228 - val_loss: 0.5099 - val_accuracy: 0.7042
Epoch 35/50
15/15 [==============================] - 9s 590ms/step - loss: 0.5002 - accuracy: 0.7441 - val_loss: 0.5375 - val_accuracy: 0.6920
Epoch 36/50
15/15 [==============================] - 9s 595ms/step - loss: 0.5177 - accuracy: 0.7334 - val_loss: 0.5177 - val_accuracy: 0.7310
Epoch 37/50
15/15 [==============================] - 9s 594ms/step - loss: 0.5083 - accuracy: 0.7350 - val_loss: 0.4946 - val_accuracy: 0.7377
Epoch 38/50
15/15 [==============================] - 9s 592ms/step - loss: 0.5352 - accuracy: 0.7137 - val_loss: 0.5071 - val_accuracy: 0.7087
Epoch 39/50
15/15 [==============================] - 9s 593ms/step - loss: 0.5055 - accuracy: 0.7249 - val_loss: 0.4995 - val_accuracy: 0.7020
Epoch 40/50
15/15 [==============================] - 9s 592ms/step - loss: 0.5122 - accuracy: 0.7329 - val_loss: 0.4876 - val_accuracy: 0.7478
Epoch 41/50
15/15 [==============================] - 9s 591ms/step - loss: 0.5174 - accuracy: 0.7377 - val_loss: 0.5005 - val_accuracy: 0.7511
Epoch 42/50
15/15 [==============================] - 9s 610ms/step - loss: 0.5140 - accuracy: 0.7340 - val_loss: 0.4946 - val_accuracy: 0.7132
Epoch 43/50
15/15 [==============================] - 9s 603ms/step - loss: 0.4989 - accuracy: 0.7420 - val_loss: 0.4774 - val_accuracy: 0.7545
Epoch 44/50
15/15 [==============================] - 10s 678ms/step - loss: 0.5004 - accuracy: 0.7431 - val_loss: 0.4940 - val_accuracy: 0.7310
Epoch 45/50
15/15 [==============================] - 9s 609ms/step - loss: 0.4834 - accuracy: 0.7516 - val_loss: 0.5288 - val_accuracy: 0.6964
Epoch 46/50
15/15 [==============================] - 9s 618ms/step - loss: 0.5017 - accuracy: 0.7377 - val_loss: 0.5231 - val_accuracy: 0.7009
Epoch 47/50
15/15 [==============================] - 9s 612ms/step - loss: 0.5044 - accuracy: 0.7276 - val_loss: 0.4883 - val_accuracy: 0.7712
Epoch 48/50
15/15 [==============================] - 9s 596ms/step - loss: 0.4893 - accuracy: 0.7580 - val_loss: 0.5250 - val_accuracy: 0.6964
Epoch 49/50
15/15 [==============================] - 9s 612ms/step - loss: 0.4843 - accuracy: 0.7328 - val_loss: 0.4848 - val_accuracy: 0.7344
Epoch 50/50
15/15 [==============================] - 9s 628ms/step - loss: 0.4827 - accuracy: 0.7527 - val_loss: 0.4828 - val_accuracy: 0.7299

12. Visualizing the results

accuracy = history_new.history["accuracy"]
val_accuracy = history_new.history["val_accuracy"]
loss = history_new.history["loss"]
val_loss = history_new.history["val_loss"]
    
epochs_range=range(epochs)

plt.figure(figsize=(8,8))
plt.subplot(1,2,1)
plt.plot(epochs_range, accuracy, label="Training Accuracy")
plt.plot(epochs_range, val_accuracy,label="Validation Accuracy")
plt.legend()
plt.title("Training and Validation Accuracy")

plt.subplot(1,2,2)
plt.plot(epochs_range, loss, label="Training Loss")
plt.plot(epochs_range, val_loss,label="Validation Loss")
plt.legend()
plt.title("Training and Validation Loss")
plt.show()

Output: (plots of training/validation accuracy and loss for the new model)

As the plots show, the overfitting problem is now largely resolved: the training and validation curves track each other closely, with both accuracies ending around 73-75%.
