Data Augmentation with tf.image in TensorFlow 2

This walkthrough uses tf.image to perform image manipulation and preprocessing. Data augmentation is one of the most common techniques for preventing overfitting.

1. Import the required libraries

import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image

for i in [tf,np,tfds,mpl]:
    print(i.__name__,": ",i.__version__,sep="")

Output:

tensorflow: 2.2.0
numpy: 1.17.4
tensorflow_datasets: 3.1.0
matplotlib: 3.1.2

2. Download the example image

image_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg"
image_path = tf.keras.utils.get_file("cat.jpg", image_url,cache_dir="./")

Image.open(image_path)

Output:

Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg
24576/17858 [=========================================] - 0s 3us/step

image_string = tf.io.read_file(image_path)
image = tf.image.decode_jpeg(image_string,channels=3)
image.shape

Output:

TensorShape([213, 320, 3])

Define a helper function that displays the original image next to its augmented version:

def visualize(original, augmented):
    plt.figure(figsize=(10,5))
    plt.subplot(1,2,1)
    plt.title("Original image")
    #plt.axis("off")  # uncomment to hide the axes
    plt.imshow(original)
    
    plt.subplot(1,2,2)
    plt.title("Augmented image")
    #plt.axis("off")  # uncomment to hide the axes
    plt.imshow(augmented)
    
    plt.tight_layout()
    plt.show()
    
visualize(image,image)

Output: (figure)

3. Augmenting a single image

3.1 Flipping the image (horizontally and vertically)

flipped_h = tf.image.flip_left_right(image)
flipped_v = tf.image.flip_up_down(image)

visualize(image,flipped_h)
visualize(image,flipped_v)

Output: (figure)
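Under the hood, a horizontal flip is just a reversal along the width axis and a vertical flip a reversal along the height axis. A minimal NumPy sketch of the same operations (NumPy stands in for the TensorFlow ops here so it runs without TensorFlow):

```python
import numpy as np

# A tiny 2x3 single-channel "image" with values chosen for easy tracking.
img = np.array([[[1], [2], [3]],
                [[4], [5], [6]]])

# Horizontal flip (like tf.image.flip_left_right): reverse axis 1 (width).
flipped_h = img[:, ::-1, :]

# Vertical flip (like tf.image.flip_up_down): reverse axis 0 (height).
flipped_v = img[::-1, :, :]

print(flipped_h[:, :, 0])  # [[3 2 1], [6 5 4]]
print(flipped_v[:, :, 0])  # [[4 5 6], [1 2 3]]
```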

3.2 Converting to grayscale

Grayscaling converts a color image into a grayscale one, reducing the number of channels from 3 to 1.

grayscaled = tf.image.rgb_to_grayscale(image)
print(grayscaled.shape)
print(tf.squeeze(grayscaled).shape)

Output:

(213, 320, 1)
(213, 320)

tf.squeeze() removes dimensions of size 1.

visualize(image, tf.squeeze(grayscaled))

Output: (figure)
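The squeeze behavior can be illustrated with NumPy's equivalent np.squeeze (used here so the sketch runs without TensorFlow):

```python
import numpy as np

gray = np.zeros((213, 320, 1))   # grayscale image with a trailing channel axis
print(np.squeeze(gray).shape)    # (213, 320): the size-1 axis is removed

# Axes with size > 1 are never touched:
print(np.squeeze(np.zeros((1, 320, 3))).shape)  # (320, 3)
```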

3.3 Adjusting saturation

saturated_3 = tf.image.adjust_saturation(image,3)
saturated_8 = tf.image.adjust_saturation(image,8)

visualize(image, saturated_3)
visualize(image, saturated_8)

Output: (figure)
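One way to build intuition for saturation adjustment is blending between a grayscale version (factor 0) and an over-saturated one (factor > 1). Note that tf.image.adjust_saturation actually scales the S channel in HSV space, so its results differ slightly from this simplified NumPy sketch:

```python
import numpy as np

def adjust_saturation(img, factor):
    """Blend away from the per-pixel gray value by `factor`, clipping to [0, 1].
    A rough analogue of tf.image.adjust_saturation (which works in HSV space)."""
    gray = img.mean(axis=-1, keepdims=True)          # simple luminance proxy
    return np.clip(gray + factor * (img - gray), 0.0, 1.0)

px = np.array([[[0.2, 0.5, 0.8]]])                   # one RGB pixel
print(adjust_saturation(px, 0.0))  # fully desaturated: all channels become 0.5
print(adjust_saturation(px, 3.0))  # over-saturated: channels pushed apart and clipped
```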

3.4 Adjusting brightness

bright_0 = tf.image.adjust_brightness(image, 0)
bright_5 = tf.image.adjust_brightness(image, 0.5)
bright_8 = tf.image.adjust_brightness(image, 0.8)
bright_10 = tf.image.adjust_brightness(image, 1)

visualize(image, bright_0)
visualize(image, bright_5)
visualize(image, bright_8)
visualize(image, bright_10)

Output: (figure)
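tf.image.adjust_brightness simply adds delta to every channel of every pixel (integer images are first converted to floats in [0, 1], which is why delta = 1 washes the image out completely). A simplified NumPy analogue, with clipping added so the result stays a valid image:

```python
import numpy as np

def adjust_brightness(img, delta):
    """Add `delta` to every pixel of a float image in [0, 1], clipping the result.
    A simplified analogue of tf.image.adjust_brightness for illustration."""
    return np.clip(img + delta, 0.0, 1.0)

img = np.array([[0.2, 0.5],
                [0.7, 0.9]])
print(adjust_brightness(img, 0.5))  # bright pixels saturate: [[0.7 1.0], [1.0 1.0]]
```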

3.5 Rotating the image

rotated_90 = tf.image.rot90(image)
rotated_180 = tf.image.rot90(rotated_90)   # equivalent to flipping both vertically and horizontally
rotated_270 = tf.image.rot90(rotated_180)
rotated_360 = tf.image.rot90(rotated_270)  # back to the original image

visualize(image, rotated_90)
visualize(image, rotated_180)
visualize(image, rotated_270)
visualize(image, rotated_360)

Output: (figure)
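The chained calls above can be written more directly: tf.image.rot90 accepts a k argument giving the number of 90° counter-clockwise rotations, e.g. tf.image.rot90(image, k=2) for a 180° rotation. The equivalence can be checked with NumPy's analogous np.rot90, used here so the sketch runs without TensorFlow:

```python
import numpy as np

img = np.arange(6).reshape(2, 3)   # [[0 1 2], [3 4 5]]

# Rotating 90 degrees four times returns the original image.
chained = np.rot90(np.rot90(np.rot90(np.rot90(img))))
print((chained == img).all())      # True

# k rotations in one call are equivalent to chaining k single rotations.
print(np.rot90(img, k=2))          # [[5 4 3], [2 1 0]]
```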

3.6 Cropping the image

# central_fraction must be in (0, 1]
cropped_1 = tf.image.central_crop(image, central_fraction=0.1)
cropped_5 = tf.image.central_crop(image, central_fraction=0.5)
cropped_8 = tf.image.central_crop(image, central_fraction=0.8)
cropped_10 = tf.image.central_crop(image, central_fraction=1)

for i in [cropped_1, cropped_5, cropped_8, cropped_10]:
    visualize(image, i)

Output: (figure)
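central_crop keeps the central central_fraction of each spatial dimension, so the output is roughly fraction * height by fraction * width. A simplified NumPy sketch of the size arithmetic (edge rounding may differ slightly from TensorFlow's implementation):

```python
import numpy as np

def central_crop(img, fraction):
    """Keep the central `fraction` of each spatial dimension.
    A simplified analogue of tf.image.central_crop for illustration."""
    h, w = img.shape[:2]
    top = int((h - h * fraction) / 2)
    left = int((w - w * fraction) / 2)
    return img[top:h - top, left:w - left]

img = np.zeros((213, 320, 3))
print(central_crop(img, 1.0).shape)  # fraction 1 keeps the whole image: (213, 320, 3)
print(central_crop(img, 0.5).shape)  # roughly half of each spatial dimension
```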

4. Augmenting a dataset

4.1 Download and load the dataset

dataset, info = tfds.load("mnist",as_supervised=True, with_info=True)
train_dataset, test_dataset = dataset["train"],dataset["test"]

Output:

Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to C:\Users\my-pc\tensorflow_datasets\mnist\3.0.1...

Dl Completed...: 100%
4/4 [00:34<00:00, 8.66s/ file]

Dataset mnist downloaded and prepared to C:\Users\my-pc\tensorflow_datasets\mnist\3.0.1. Subsequent calls will reuse this data.

num_train_examples = info.splits["train"].num_examples
num_train_examples

Output:

60000

4.2 Augmenting the dataset

def convert(image, label):
    image = tf.image.convert_image_dtype(image, tf.float32)
    return image, label

def augment(image, label):
    image, label = convert(image, label)
    image = tf.image.resize_with_crop_or_pad(image, 34, 34) # pad 3 pixels on each side
    image = tf.image.random_crop(image, size=[28,28,1]) # randomly crop back to 28x28
    image = tf.image.random_brightness(image, max_delta=0.5) # randomly adjust brightness
    return image, label
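The pad-then-random-crop pair in augment() shifts each digit by up to ±3 pixels in both directions. The same idea can be sketched in NumPy (used here so the sketch runs without TensorFlow; a fixed seed makes it reproducible):

```python
import numpy as np

def random_shift_crop(img, pad=3, rng=None):
    """Pad `pad` pixels on each side, then crop back to the original size at a
    random offset -- the same shift effect as resize_with_crop_or_pad followed
    by random_crop in the augment() function above."""
    if rng is None:
        rng = np.random.default_rng(0)
    h, w = img.shape[:2]
    padded = np.pad(img, ((pad, pad), (pad, pad), (0, 0)))
    top = rng.integers(0, 2 * pad + 1)   # vertical offset in [0, 2*pad]
    left = rng.integers(0, 2 * pad + 1)  # horizontal offset in [0, 2*pad]
    return padded[top:top + h, left:left + w]

img = np.ones((28, 28, 1))
print(random_shift_crop(img).shape)  # (28, 28, 1): same size, contents shifted
```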

batch_size = 64
num_examples = 2048  # use a small subset so that overfitting shows up clearly

augmented_train_batches = (train_dataset
                          .take(num_examples)
                          .cache()
                          .shuffle(num_train_examples//4)
                          .map(augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                          .batch(batch_size)
                          .prefetch(tf.data.experimental.AUTOTUNE))

non_augmented_train_batches = (train_dataset
                              .take(num_examples)
                              .cache()
                              .shuffle(num_train_examples//4)
                              .map(convert, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                              .batch(batch_size)
                              .prefetch(tf.data.experimental.AUTOTUNE))

validation_batches = (test_dataset
                     .map(convert, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                     .batch(2*batch_size))

5. Build and train the model

5.1 Model-building function

def buildModel():
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28,28,1)),
        tf.keras.layers.Dense(4096, activation="relu"),
        tf.keras.layers.Dense(4096, activation="relu"),
        tf.keras.layers.Dense(10)
    ])
    model.summary()

    model.compile(optimizer="adam",
                 loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                 metrics=["accuracy"])
    return model

5.2 Train the model on non-augmented data

model_non_augment = buildModel()
history_non_augment = model_non_augment.fit(non_augmented_train_batches,
                                            epochs=50,
                                            validation_data=validation_batches)

Output:

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_2 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_6 (Dense)              (None, 4096)              3215360   
_________________________________________________________________
dense_7 (Dense)              (None, 4096)              16781312  
_________________________________________________________________
dense_8 (Dense)              (None, 10)                40970     
=================================================================
Total params: 20,037,642
Trainable params: 20,037,642
Non-trainable params: 0
_________________________________________________________________
Epoch 1/50
32/32 [==============================] - 1s 31ms/step - loss: 0.7841 - accuracy: 0.7520 - val_loss: 0.3936 - val_accuracy: 0.8803
Epoch 2/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1908 - accuracy: 0.9399 - val_loss: 0.4234 - val_accuracy: 0.8898
Epoch 3/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0936 - accuracy: 0.9712 - val_loss: 0.2674 - val_accuracy: 0.9270
Epoch 4/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0475 - accuracy: 0.9878 - val_loss: 0.2828 - val_accuracy: 0.9247
Epoch 5/50
32/32 [==============================] - 1s 22ms/step - loss: 0.0370 - accuracy: 0.9863 - val_loss: 0.3070 - val_accuracy: 0.9245
Epoch 6/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0551 - accuracy: 0.9819 - val_loss: 0.3165 - val_accuracy: 0.9264
Epoch 7/50
32/32 [==============================] - 1s 22ms/step - loss: 0.0233 - accuracy: 0.9951 - val_loss: 0.4092 - val_accuracy: 0.9121
Epoch 8/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0193 - accuracy: 0.9946 - val_loss: 0.3974 - val_accuracy: 0.9171
Epoch 9/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0095 - accuracy: 0.9971 - val_loss: 0.3894 - val_accuracy: 0.9208
Epoch 10/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0079 - accuracy: 0.9971 - val_loss: 0.4576 - val_accuracy: 0.9197
Epoch 11/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0163 - accuracy: 0.9951 - val_loss: 0.3831 - val_accuracy: 0.9279
Epoch 12/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0222 - accuracy: 0.9922 - val_loss: 0.4183 - val_accuracy: 0.9205
Epoch 13/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0407 - accuracy: 0.9897 - val_loss: 0.6182 - val_accuracy: 0.8940
Epoch 14/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0641 - accuracy: 0.9819 - val_loss: 0.6043 - val_accuracy: 0.8929
Epoch 15/50
32/32 [==============================] - 1s 24ms/step - loss: 0.0640 - accuracy: 0.9839 - val_loss: 0.4911 - val_accuracy: 0.9118
Epoch 16/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0798 - accuracy: 0.9780 - val_loss: 0.5456 - val_accuracy: 0.9069
Epoch 17/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0807 - accuracy: 0.9741 - val_loss: 0.4527 - val_accuracy: 0.9160
Epoch 18/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0609 - accuracy: 0.9844 - val_loss: 0.6052 - val_accuracy: 0.9029
Epoch 19/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0330 - accuracy: 0.9907 - val_loss: 0.4286 - val_accuracy: 0.9210
Epoch 20/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0094 - accuracy: 0.9985 - val_loss: 0.4022 - val_accuracy: 0.9283
Epoch 21/50
32/32 [==============================] - 1s 22ms/step - loss: 0.0024 - accuracy: 0.9990 - val_loss: 0.3571 - val_accuracy: 0.9370
Epoch 22/50
32/32 [==============================] - 1s 23ms/step - loss: 3.6756e-04 - accuracy: 1.0000 - val_loss: 0.3663 - val_accuracy: 0.9374
Epoch 23/50
32/32 [==============================] - 1s 24ms/step - loss: 1.2456e-04 - accuracy: 1.0000 - val_loss: 0.3668 - val_accuracy: 0.9385
Epoch 24/50
32/32 [==============================] - 1s 24ms/step - loss: 8.0307e-05 - accuracy: 1.0000 - val_loss: 0.3704 - val_accuracy: 0.9383
Epoch 25/50
32/32 [==============================] - 1s 23ms/step - loss: 6.0402e-05 - accuracy: 1.0000 - val_loss: 0.3741 - val_accuracy: 0.9391
Epoch 26/50
32/32 [==============================] - 1s 23ms/step - loss: 4.7941e-05 - accuracy: 1.0000 - val_loss: 0.3785 - val_accuracy: 0.9390
Epoch 27/50
32/32 [==============================] - 1s 23ms/step - loss: 3.9313e-05 - accuracy: 1.0000 - val_loss: 0.3831 - val_accuracy: 0.9395
Epoch 28/50
32/32 [==============================] - 1s 23ms/step - loss: 3.2839e-05 - accuracy: 1.0000 - val_loss: 0.3871 - val_accuracy: 0.9392
Epoch 29/50
32/32 [==============================] - 1s 23ms/step - loss: 2.7750e-05 - accuracy: 1.0000 - val_loss: 0.3913 - val_accuracy: 0.9392
Epoch 30/50
32/32 [==============================] - 1s 23ms/step - loss: 2.3494e-05 - accuracy: 1.0000 - val_loss: 0.3952 - val_accuracy: 0.9391
Epoch 31/50
32/32 [==============================] - 1s 22ms/step - loss: 2.0309e-05 - accuracy: 1.0000 - val_loss: 0.3989 - val_accuracy: 0.9393
Epoch 32/50
32/32 [==============================] - 1s 22ms/step - loss: 1.7693e-05 - accuracy: 1.0000 - val_loss: 0.4028 - val_accuracy: 0.9392
Epoch 33/50
32/32 [==============================] - 1s 23ms/step - loss: 1.5492e-05 - accuracy: 1.0000 - val_loss: 0.4064 - val_accuracy: 0.9390
Epoch 34/50
32/32 [==============================] - 1s 23ms/step - loss: 1.3694e-05 - accuracy: 1.0000 - val_loss: 0.4099 - val_accuracy: 0.9391
Epoch 35/50
32/32 [==============================] - 1s 23ms/step - loss: 1.2217e-05 - accuracy: 1.0000 - val_loss: 0.4129 - val_accuracy: 0.9390
Epoch 36/50
32/32 [==============================] - 1s 23ms/step - loss: 1.0953e-05 - accuracy: 1.0000 - val_loss: 0.4159 - val_accuracy: 0.9391
Epoch 37/50
32/32 [==============================] - 1s 23ms/step - loss: 9.8946e-06 - accuracy: 1.0000 - val_loss: 0.4189 - val_accuracy: 0.9390
Epoch 38/50
32/32 [==============================] - 1s 23ms/step - loss: 8.9504e-06 - accuracy: 1.0000 - val_loss: 0.4216 - val_accuracy: 0.9392
Epoch 39/50
32/32 [==============================] - 1s 22ms/step - loss: 8.1821e-06 - accuracy: 1.0000 - val_loss: 0.4244 - val_accuracy: 0.9391
Epoch 40/50
32/32 [==============================] - 1s 23ms/step - loss: 7.5098e-06 - accuracy: 1.0000 - val_loss: 0.4268 - val_accuracy: 0.9395
Epoch 41/50
32/32 [==============================] - 1s 23ms/step - loss: 6.8791e-06 - accuracy: 1.0000 - val_loss: 0.4291 - val_accuracy: 0.9394
Epoch 42/50
32/32 [==============================] - 1s 23ms/step - loss: 6.3821e-06 - accuracy: 1.0000 - val_loss: 0.4312 - val_accuracy: 0.9394
Epoch 43/50
32/32 [==============================] - 1s 22ms/step - loss: 5.8990e-06 - accuracy: 1.0000 - val_loss: 0.4338 - val_accuracy: 0.9393
Epoch 44/50
32/32 [==============================] - 1s 23ms/step - loss: 5.4862e-06 - accuracy: 1.0000 - val_loss: 0.4359 - val_accuracy: 0.9393
Epoch 45/50
32/32 [==============================] - 1s 23ms/step - loss: 5.1077e-06 - accuracy: 1.0000 - val_loss: 0.4382 - val_accuracy: 0.9389
Epoch 46/50
32/32 [==============================] - 1s 22ms/step - loss: 4.7791e-06 - accuracy: 1.0000 - val_loss: 0.4401 - val_accuracy: 0.9390
Epoch 47/50
32/32 [==============================] - 1s 22ms/step - loss: 4.4674e-06 - accuracy: 1.0000 - val_loss: 0.4421 - val_accuracy: 0.9390
Epoch 48/50
32/32 [==============================] - 1s 23ms/step - loss: 4.1864e-06 - accuracy: 1.0000 - val_loss: 0.4440 - val_accuracy: 0.9390
Epoch 49/50
32/32 [==============================] - 1s 23ms/step - loss: 3.9505e-06 - accuracy: 1.0000 - val_loss: 0.4460 - val_accuracy: 0.9388
Epoch 50/50
32/32 [==============================] - 1s 24ms/step - loss: 3.7084e-06 - accuracy: 1.0000 - val_loss: 0.4478 - val_accuracy: 0.9389

acc = history_non_augment.history['accuracy']
val_acc = history_non_augment.history['val_accuracy']

loss = history_non_augment.history['loss']
val_loss = history_non_augment.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1.1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([-0.1,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

Output: (figure)

5.3 Train the model on augmented data

model_with_augment = buildModel()
history_with_augment = model_with_augment.fit(augmented_train_batches,
                                              epochs=50,
                                              validation_data=validation_batches)

Output:

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_3 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_9 (Dense)              (None, 4096)              3215360   
_________________________________________________________________
dense_10 (Dense)             (None, 4096)              16781312  
_________________________________________________________________
dense_11 (Dense)             (None, 10)                40970     
=================================================================
Total params: 20,037,642
Trainable params: 20,037,642
Non-trainable params: 0
_________________________________________________________________
Epoch 1/50
32/32 [==============================] - 1s 27ms/step - loss: 2.3455 - accuracy: 0.3154 - val_loss: 1.1609 - val_accuracy: 0.6739
Epoch 2/50
32/32 [==============================] - 1s 23ms/step - loss: 1.3705 - accuracy: 0.5386 - val_loss: 0.7378 - val_accuracy: 0.7757
Epoch 3/50
32/32 [==============================] - 1s 23ms/step - loss: 0.9722 - accuracy: 0.6851 - val_loss: 0.4942 - val_accuracy: 0.8571
Epoch 4/50
32/32 [==============================] - 1s 23ms/step - loss: 0.7568 - accuracy: 0.7441 - val_loss: 0.3969 - val_accuracy: 0.8749
Epoch 5/50
32/32 [==============================] - 1s 23ms/step - loss: 0.6670 - accuracy: 0.7891 - val_loss: 0.3878 - val_accuracy: 0.8604
Epoch 6/50
32/32 [==============================] - 1s 23ms/step - loss: 0.6065 - accuracy: 0.8022 - val_loss: 0.3017 - val_accuracy: 0.9055
Epoch 7/50
32/32 [==============================] - 1s 24ms/step - loss: 0.5568 - accuracy: 0.8057 - val_loss: 0.2845 - val_accuracy: 0.9100
Epoch 8/50
32/32 [==============================] - 1s 24ms/step - loss: 0.5136 - accuracy: 0.8291 - val_loss: 0.2541 - val_accuracy: 0.9200
Epoch 9/50
32/32 [==============================] - 1s 25ms/step - loss: 0.4372 - accuracy: 0.8574 - val_loss: 0.2279 - val_accuracy: 0.9310
Epoch 10/50
32/32 [==============================] - 1s 24ms/step - loss: 0.4240 - accuracy: 0.8608 - val_loss: 0.2220 - val_accuracy: 0.9319
Epoch 11/50
32/32 [==============================] - 1s 25ms/step - loss: 0.4186 - accuracy: 0.8589 - val_loss: 0.2300 - val_accuracy: 0.9290
Epoch 12/50
32/32 [==============================] - 1s 25ms/step - loss: 0.3784 - accuracy: 0.8706 - val_loss: 0.2086 - val_accuracy: 0.9353
Epoch 13/50
32/32 [==============================] - 1s 24ms/step - loss: 0.3523 - accuracy: 0.8838 - val_loss: 0.2269 - val_accuracy: 0.9270
Epoch 14/50
32/32 [==============================] - 1s 24ms/step - loss: 0.3586 - accuracy: 0.8813 - val_loss: 0.2001 - val_accuracy: 0.9380
Epoch 15/50
32/32 [==============================] - 1s 24ms/step - loss: 0.3613 - accuracy: 0.8799 - val_loss: 0.2310 - val_accuracy: 0.9265
Epoch 16/50
32/32 [==============================] - 1s 24ms/step - loss: 0.3332 - accuracy: 0.8901 - val_loss: 0.2240 - val_accuracy: 0.9276
Epoch 17/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2640 - accuracy: 0.9121 - val_loss: 0.1797 - val_accuracy: 0.9417
Epoch 18/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2956 - accuracy: 0.9062 - val_loss: 0.2051 - val_accuracy: 0.9353
Epoch 19/50
32/32 [==============================] - 1s 23ms/step - loss: 0.3339 - accuracy: 0.8887 - val_loss: 0.1981 - val_accuracy: 0.9368
Epoch 20/50
32/32 [==============================] - 1s 23ms/step - loss: 0.3056 - accuracy: 0.9053 - val_loss: 0.2208 - val_accuracy: 0.9329
Epoch 21/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2541 - accuracy: 0.9126 - val_loss: 0.1700 - val_accuracy: 0.9453
Epoch 22/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2389 - accuracy: 0.9248 - val_loss: 0.1752 - val_accuracy: 0.9452
Epoch 23/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2478 - accuracy: 0.9180 - val_loss: 0.1723 - val_accuracy: 0.9461
Epoch 24/50
32/32 [==============================] - 1s 24ms/step - loss: 0.2275 - accuracy: 0.9292 - val_loss: 0.2045 - val_accuracy: 0.9379
Epoch 25/50
32/32 [==============================] - 1s 25ms/step - loss: 0.2495 - accuracy: 0.9185 - val_loss: 0.1836 - val_accuracy: 0.9428
Epoch 26/50
32/32 [==============================] - 1s 24ms/step - loss: 0.2581 - accuracy: 0.9199 - val_loss: 0.1581 - val_accuracy: 0.9507
Epoch 27/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2552 - accuracy: 0.9170 - val_loss: 0.1798 - val_accuracy: 0.9436
Epoch 28/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2268 - accuracy: 0.9336 - val_loss: 0.1763 - val_accuracy: 0.9438
Epoch 29/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2276 - accuracy: 0.9321 - val_loss: 0.1668 - val_accuracy: 0.9472
Epoch 30/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1971 - accuracy: 0.9395 - val_loss: 0.1537 - val_accuracy: 0.9503
Epoch 31/50
32/32 [==============================] - 1s 24ms/step - loss: 0.2141 - accuracy: 0.9341 - val_loss: 0.1570 - val_accuracy: 0.9506
Epoch 32/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1763 - accuracy: 0.9434 - val_loss: 0.1516 - val_accuracy: 0.9530
Epoch 33/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1866 - accuracy: 0.9355 - val_loss: 0.1556 - val_accuracy: 0.9536
Epoch 34/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1979 - accuracy: 0.9341 - val_loss: 0.1811 - val_accuracy: 0.9450
Epoch 35/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1973 - accuracy: 0.9390 - val_loss: 0.1626 - val_accuracy: 0.9515
Epoch 36/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2121 - accuracy: 0.9263 - val_loss: 0.1596 - val_accuracy: 0.9521
Epoch 37/50
32/32 [==============================] - 1s 22ms/step - loss: 0.1693 - accuracy: 0.9429 - val_loss: 0.1567 - val_accuracy: 0.9517
Epoch 38/50
32/32 [==============================] - 1s 25ms/step - loss: 0.2130 - accuracy: 0.9336 - val_loss: 0.1511 - val_accuracy: 0.9536
Epoch 39/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1854 - accuracy: 0.9409 - val_loss: 0.1744 - val_accuracy: 0.9451
Epoch 40/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1764 - accuracy: 0.9399 - val_loss: 0.1663 - val_accuracy: 0.9492
Epoch 41/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1800 - accuracy: 0.9438 - val_loss: 0.1676 - val_accuracy: 0.9481
Epoch 42/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2094 - accuracy: 0.9292 - val_loss: 0.1614 - val_accuracy: 0.9520
Epoch 43/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1776 - accuracy: 0.9370 - val_loss: 0.1511 - val_accuracy: 0.9566
Epoch 44/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1476 - accuracy: 0.9526 - val_loss: 0.1725 - val_accuracy: 0.9498
Epoch 45/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1537 - accuracy: 0.9482 - val_loss: 0.1611 - val_accuracy: 0.9509
Epoch 46/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1292 - accuracy: 0.9541 - val_loss: 0.1543 - val_accuracy: 0.9544
Epoch 47/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1605 - accuracy: 0.9458 - val_loss: 0.1591 - val_accuracy: 0.9529
Epoch 48/50
32/32 [==============================] - 1s 24ms/step - loss: 0.2012 - accuracy: 0.9390 - val_loss: 0.1858 - val_accuracy: 0.9453
Epoch 49/50
32/32 [==============================] - 1s 25ms/step - loss: 0.2423 - accuracy: 0.9185 - val_loss: 0.1525 - val_accuracy: 0.9553
Epoch 50/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1835 - accuracy: 0.9434 - val_loss: 0.1819 - val_accuracy: 0.9498

acc = history_with_augment.history['accuracy']
val_acc = history_with_augment.history['val_accuracy']

loss = history_with_augment.history['loss']
val_loss = history_with_augment.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1.1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([-0.1,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

Output: (figure)

6. Summary

The model trained on augmented data reaches a validation accuracy of 94.98%, versus 93.89% for the model trained on non-augmented data: data augmentation gives a modestly better result.
