本案例使用 tf.image 演示图像操作与预处理的过程。数据增强是防止过拟合最常用的手段之一。
1. 导入所需的库
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image
# Report the version of every core library so the run can be reproduced.
for module in (tf, np, tfds, mpl):
    print(f"{module.__name__}: {module.__version__}")
输出:
tensorflow: 2.2.0
numpy: 1.17.4
tensorflow_datasets: 3.1.0
matplotlib: 3.1.2
2. 下载示例图像
# Download the sample cat photo (cached locally below cache_dir="./");
# image_path is the path of the local copy returned by get_file.
image_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg"
image_path = tf.keras.utils.get_file("cat.jpg", image_url,cache_dir="./")
# In a notebook, this bare expression displays the downloaded image inline.
Image.open(image_path)
输出:
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg
24576/17858 [=========================================] - 0s 3us/step
# Read the raw JPEG bytes from disk and decode them into an image tensor
# with 3 colour channels.
image_string = tf.io.read_file(image_path)
image = tf.image.decode_jpeg(image_string,channels=3)
# (height, width, channels) -- the recorded output shows (213, 320, 3).
image.shape
输出:
TensorShape([213, 320, 3])
定义一个用于比对原始图像和数据增强后图像的函数
def visualize(original, augmented, show_axis=True):
    """Show an original image and its augmented version side by side.

    Args:
        original: Image array or tensor shaped (H, W), (H, W, 1) or (H, W, 3).
        augmented: Image array or tensor in any of the same layouts.
        show_axis: When False, hide the axes/ticks of both panels.

    Single-channel images are drawn with the "gray" colormap. Without it,
    matplotlib renders a 2-D array with its default false-colour colormap,
    so a grayscaled image would not actually look gray.
    """
    plt.figure(figsize=(10, 5))
    panels = (("Original image", original), ("Augmented image", augmented))
    for position, (title, img) in enumerate(panels, start=1):
        arr = np.asarray(img)
        # imshow accepts (H, W), (H, W, 3) or (H, W, 4) but not (H, W, 1):
        # drop a trailing singleton channel axis.
        if arr.ndim == 3 and arr.shape[-1] == 1:
            arr = arr[..., 0]
        plt.subplot(1, 2, position)
        plt.title(title)
        if not show_axis:
            plt.axis("off")
        plt.imshow(arr, cmap="gray" if arr.ndim == 2 else None)
    plt.tight_layout()
    plt.show()


visualize(image, image)
输出:
3. 单个图像数据增强
3.1 翻转图像(垂直和水平)
# Mirror the image left<->right and top<->bottom, then compare each with the original.
flipped_h, flipped_v = tf.image.flip_left_right(image), tf.image.flip_up_down(image)
for flipped in (flipped_h, flipped_v):
    visualize(image, flipped)
输出:
3.2 图像灰阶化
图像灰阶化是将彩色图像转变成灰度图像,通道数由3变为1。
# Convert RGB to a single channel; the channel axis is kept with size 1,
# and tf.squeeze removes it.
grayscaled = tf.image.rgb_to_grayscale(image)
for shape in (grayscaled.shape, tf.squeeze(grayscaled).shape):
    print(shape)
输出:
(213, 320, 1)
(213, 320)
tf.squeeze() 会删除张量中所有长度为 1 的维度。这里去掉的是大小为 1 的通道维，因此形状由 (213, 320, 1) 变为 (213, 320)。
# Drop the size-1 channel axis so the grayscale image is passed as a 2-D (H, W) array.
visualize(image, tf.squeeze(grayscaled))
输出:
3.3 改变图像饱和度
# Boost colour saturation by factors of 3 and 8, then compare each with the original.
saturated_3, saturated_8 = (
    tf.image.adjust_saturation(image, factor) for factor in (3, 8)
)
for saturated in (saturated_3, saturated_8):
    visualize(image, saturated)
输出:
3.4 改变图像亮度
# Shift brightness by delta = 0, 0.5, 0.8 and 1 and compare each result with the original.
bright_0, bright_5, bright_8, bright_10 = (
    tf.image.adjust_brightness(image, delta) for delta in (0, 0.5, 0.8, 1)
)
for brightened in (bright_0, bright_5, bright_8, bright_10):
    visualize(image, brightened)
输出:
3.5 图像旋转
# tf.image.rot90 rotates counter-clockwise by k * 90 degrees.
# Note: a 180-degree rotation equals flipping BOTH up-down and left-right;
# none of these rotations is equivalent to a single vertical flip.
rotated_90 = tf.image.rot90(image, k=1)
rotated_180 = tf.image.rot90(image, k=2)
rotated_270 = tf.image.rot90(image, k=3)
rotated_360 = tf.image.rot90(image, k=4)  # full turn: identical to the original
for rotated in (rotated_90, rotated_180, rotated_270, rotated_360):
    visualize(image, rotated)
输出:
3.6 图像裁剪
# Keep only the central `central_fraction` of the image; valid range is (0, 1].
crop_fractions = (0.1, 0.5, 0.8, 1)
cropped_1, cropped_5, cropped_8, cropped_10 = (
    tf.image.central_crop(image, central_fraction=fraction) for fraction in crop_fractions
)
for cropped in (cropped_1, cropped_5, cropped_8, cropped_10):
    visualize(image, cropped)
输出:
4. 数据集数据增强
4.1 下载并导入数据集
# Load MNIST as (image, label) pairs (as_supervised=True) together with its
# metadata object (with_info=True); the first call downloads and prepares the data.
dataset, info = tfds.load("mnist",as_supervised=True, with_info=True)
train_dataset, test_dataset = dataset["train"],dataset["test"]
输出:
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to C:\Users\my-pc\tensorflow_datasets\mnist\3.0.1...
Dl Completed...: 100%
4/4 [00:34<00:00, 8.66s/ file]
Dataset mnist downloaded and prepared to C:\Users\my-pc\tensorflow_datasets\mnist\3.0.1. Subsequent calls will reuse this data.
# Size of the full training split, read from the dataset metadata
# (the recorded output shows 60000).
num_train_examples = info.splits["train"].num_examples
num_train_examples
输出:
60000
4.2 对数据集进行数据增强
def convert(image, label):
    """Return `image` cast to float32 with tf.image.convert_image_dtype, and `label` unchanged.

    For integer inputs (e.g. uint8 pixels) convert_image_dtype also rescales
    the values into [0, 1].
    """
    return tf.image.convert_image_dtype(image, tf.float32), label
def augment(image, label):
    """Randomly augment one 28x28x1 (MNIST) training example.

    Steps:
      1. Cast to float32 via `convert` (uint8 pixels are rescaled to [0, 1]).
      2. Pad 28x28 -> 34x34 (3 pixels on every side).
      3. Randomly crop back to 28x28, i.e. a random shift of up to 3 pixels.
      4. Randomly shift brightness by a delta drawn from [-0.5, 0.5).
      5. Clip to [0, 1] so augmented training images stay in the same value
         range as the validation images produced by `convert`.

    Returns:
        The augmented (image, label) pair.
    """
    image, label = convert(image, label)
    image = tf.image.resize_with_crop_or_pad(image, 34, 34)
    image = tf.image.random_crop(image, size=[28, 28, 1])
    # random_brightness can darken as well as brighten the image.
    image = tf.image.random_brightness(image, max_delta=0.5)
    # Without clipping, pixels can fall anywhere in [-0.5, 1.5].
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label
batch_size = 64
num_examples = 2048  # deliberately small training subset so that overfitting becomes visible

AUTOTUNE = tf.data.experimental.AUTOTUNE


def _make_train_batches(map_fn):
    """Cache, shuffle, map, batch and prefetch the first `num_examples` training examples."""
    return (
        train_dataset.take(num_examples)
        .cache()
        .shuffle(num_train_examples // 4)
        .map(map_fn, num_parallel_calls=AUTOTUNE)
        .batch(batch_size)
        .prefetch(AUTOTUNE)
    )


# Identical pipelines except for the per-example transform.
augmented_train_batches = _make_train_batches(augment)
non_augmented_train_batches = _make_train_batches(convert)

# Validation uses the whole test split, only converted (never augmented).
validation_batches = test_dataset.map(convert, num_parallel_calls=AUTOTUNE).batch(2 * batch_size)
5. 构建模型并训练
5.1 模型构建函数
def build_model():
    """Build, summarize and compile a deliberately over-sized MLP for 28x28x1 inputs.

    Two 4096-unit hidden layers (about 20M parameters) give the network enough
    capacity to memorize the small training subset. That is what makes
    overfitting -- and the benefit of data augmentation -- visible.

    Returns:
        A compiled tf.keras.Sequential model. Its last layer has no activation,
        so it outputs raw logits for 10 classes; hence ``from_logits=True``
        in the loss.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
        tf.keras.layers.Dense(4096, activation="relu"),
        tf.keras.layers.Dense(4096, activation="relu"),
        tf.keras.layers.Dense(10),
    ])
    model.summary()  # print the layer table and parameter counts
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return model


# Backward-compatible alias for the original (misspelled) name used by callers.
bulidModel = build_model
5.2 使用未增强的数据进行模型训练
# Baseline: train a fresh model on the NON-augmented subset for 50 epochs,
# evaluating on validation_batches after every epoch.
model_non_augment = bulidModel()
history_non_augment = model_non_augment.fit(non_augmented_train_batches,
epochs=50,
validation_data=validation_batches)
输出:
Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten_2 (Flatten) (None, 784) 0
_________________________________________________________________
dense_6 (Dense) (None, 4096) 3215360
_________________________________________________________________
dense_7 (Dense) (None, 4096) 16781312
_________________________________________________________________
dense_8 (Dense) (None, 10) 40970
=================================================================
Total params: 20,037,642
Trainable params: 20,037,642
Non-trainable params: 0
_________________________________________________________________
Epoch 1/50
32/32 [==============================] - 1s 31ms/step - loss: 0.7841 - accuracy: 0.7520 - val_loss: 0.3936 - val_accuracy: 0.8803
Epoch 2/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1908 - accuracy: 0.9399 - val_loss: 0.4234 - val_accuracy: 0.8898
Epoch 3/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0936 - accuracy: 0.9712 - val_loss: 0.2674 - val_accuracy: 0.9270
Epoch 4/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0475 - accuracy: 0.9878 - val_loss: 0.2828 - val_accuracy: 0.9247
Epoch 5/50
32/32 [==============================] - 1s 22ms/step - loss: 0.0370 - accuracy: 0.9863 - val_loss: 0.3070 - val_accuracy: 0.9245
Epoch 6/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0551 - accuracy: 0.9819 - val_loss: 0.3165 - val_accuracy: 0.9264
Epoch 7/50
32/32 [==============================] - 1s 22ms/step - loss: 0.0233 - accuracy: 0.9951 - val_loss: 0.4092 - val_accuracy: 0.9121
Epoch 8/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0193 - accuracy: 0.9946 - val_loss: 0.3974 - val_accuracy: 0.9171
Epoch 9/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0095 - accuracy: 0.9971 - val_loss: 0.3894 - val_accuracy: 0.9208
Epoch 10/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0079 - accuracy: 0.9971 - val_loss: 0.4576 - val_accuracy: 0.9197
Epoch 11/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0163 - accuracy: 0.9951 - val_loss: 0.3831 - val_accuracy: 0.9279
Epoch 12/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0222 - accuracy: 0.9922 - val_loss: 0.4183 - val_accuracy: 0.9205
Epoch 13/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0407 - accuracy: 0.9897 - val_loss: 0.6182 - val_accuracy: 0.8940
Epoch 14/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0641 - accuracy: 0.9819 - val_loss: 0.6043 - val_accuracy: 0.8929
Epoch 15/50
32/32 [==============================] - 1s 24ms/step - loss: 0.0640 - accuracy: 0.9839 - val_loss: 0.4911 - val_accuracy: 0.9118
Epoch 16/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0798 - accuracy: 0.9780 - val_loss: 0.5456 - val_accuracy: 0.9069
Epoch 17/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0807 - accuracy: 0.9741 - val_loss: 0.4527 - val_accuracy: 0.9160
Epoch 18/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0609 - accuracy: 0.9844 - val_loss: 0.6052 - val_accuracy: 0.9029
Epoch 19/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0330 - accuracy: 0.9907 - val_loss: 0.4286 - val_accuracy: 0.9210
Epoch 20/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0094 - accuracy: 0.9985 - val_loss: 0.4022 - val_accuracy: 0.9283
Epoch 21/50
32/32 [==============================] - 1s 22ms/step - loss: 0.0024 - accuracy: 0.9990 - val_loss: 0.3571 - val_accuracy: 0.9370
Epoch 22/50
32/32 [==============================] - 1s 23ms/step - loss: 3.6756e-04 - accuracy: 1.0000 - val_loss: 0.3663 - val_accuracy: 0.9374
Epoch 23/50
32/32 [==============================] - 1s 24ms/step - loss: 1.2456e-04 - accuracy: 1.0000 - val_loss: 0.3668 - val_accuracy: 0.9385
Epoch 24/50
32/32 [==============================] - 1s 24ms/step - loss: 8.0307e-05 - accuracy: 1.0000 - val_loss: 0.3704 - val_accuracy: 0.9383
Epoch 25/50
32/32 [==============================] - 1s 23ms/step - loss: 6.0402e-05 - accuracy: 1.0000 - val_loss: 0.3741 - val_accuracy: 0.9391
Epoch 26/50
32/32 [==============================] - 1s 23ms/step - loss: 4.7941e-05 - accuracy: 1.0000 - val_loss: 0.3785 - val_accuracy: 0.9390
Epoch 27/50
32/32 [==============================] - 1s 23ms/step - loss: 3.9313e-05 - accuracy: 1.0000 - val_loss: 0.3831 - val_accuracy: 0.9395
Epoch 28/50
32/32 [==============================] - 1s 23ms/step - loss: 3.2839e-05 - accuracy: 1.0000 - val_loss: 0.3871 - val_accuracy: 0.9392
Epoch 29/50
32/32 [==============================] - 1s 23ms/step - loss: 2.7750e-05 - accuracy: 1.0000 - val_loss: 0.3913 - val_accuracy: 0.9392
Epoch 30/50
32/32 [==============================] - 1s 23ms/step - loss: 2.3494e-05 - accuracy: 1.0000 - val_loss: 0.3952 - val_accuracy: 0.9391
Epoch 31/50
32/32 [==============================] - 1s 22ms/step - loss: 2.0309e-05 - accuracy: 1.0000 - val_loss: 0.3989 - val_accuracy: 0.9393
Epoch 32/50
32/32 [==============================] - 1s 22ms/step - loss: 1.7693e-05 - accuracy: 1.0000 - val_loss: 0.4028 - val_accuracy: 0.9392
Epoch 33/50
32/32 [==============================] - 1s 23ms/step - loss: 1.5492e-05 - accuracy: 1.0000 - val_loss: 0.4064 - val_accuracy: 0.9390
Epoch 34/50
32/32 [==============================] - 1s 23ms/step - loss: 1.3694e-05 - accuracy: 1.0000 - val_loss: 0.4099 - val_accuracy: 0.9391
Epoch 35/50
32/32 [==============================] - 1s 23ms/step - loss: 1.2217e-05 - accuracy: 1.0000 - val_loss: 0.4129 - val_accuracy: 0.9390
Epoch 36/50
32/32 [==============================] - 1s 23ms/step - loss: 1.0953e-05 - accuracy: 1.0000 - val_loss: 0.4159 - val_accuracy: 0.9391
Epoch 37/50
32/32 [==============================] - 1s 23ms/step - loss: 9.8946e-06 - accuracy: 1.0000 - val_loss: 0.4189 - val_accuracy: 0.9390
Epoch 38/50
32/32 [==============================] - 1s 23ms/step - loss: 8.9504e-06 - accuracy: 1.0000 - val_loss: 0.4216 - val_accuracy: 0.9392
Epoch 39/50
32/32 [==============================] - 1s 22ms/step - loss: 8.1821e-06 - accuracy: 1.0000 - val_loss: 0.4244 - val_accuracy: 0.9391
Epoch 40/50
32/32 [==============================] - 1s 23ms/step - loss: 7.5098e-06 - accuracy: 1.0000 - val_loss: 0.4268 - val_accuracy: 0.9395
Epoch 41/50
32/32 [==============================] - 1s 23ms/step - loss: 6.8791e-06 - accuracy: 1.0000 - val_loss: 0.4291 - val_accuracy: 0.9394
Epoch 42/50
32/32 [==============================] - 1s 23ms/step - loss: 6.3821e-06 - accuracy: 1.0000 - val_loss: 0.4312 - val_accuracy: 0.9394
Epoch 43/50
32/32 [==============================] - 1s 22ms/step - loss: 5.8990e-06 - accuracy: 1.0000 - val_loss: 0.4338 - val_accuracy: 0.9393
Epoch 44/50
32/32 [==============================] - 1s 23ms/step - loss: 5.4862e-06 - accuracy: 1.0000 - val_loss: 0.4359 - val_accuracy: 0.9393
Epoch 45/50
32/32 [==============================] - 1s 23ms/step - loss: 5.1077e-06 - accuracy: 1.0000 - val_loss: 0.4382 - val_accuracy: 0.9389
Epoch 46/50
32/32 [==============================] - 1s 22ms/step - loss: 4.7791e-06 - accuracy: 1.0000 - val_loss: 0.4401 - val_accuracy: 0.9390
Epoch 47/50
32/32 [==============================] - 1s 22ms/step - loss: 4.4674e-06 - accuracy: 1.0000 - val_loss: 0.4421 - val_accuracy: 0.9390
Epoch 48/50
32/32 [==============================] - 1s 23ms/step - loss: 4.1864e-06 - accuracy: 1.0000 - val_loss: 0.4440 - val_accuracy: 0.9390
Epoch 49/50
32/32 [==============================] - 1s 23ms/step - loss: 3.9505e-06 - accuracy: 1.0000 - val_loss: 0.4460 - val_accuracy: 0.9388
Epoch 50/50
32/32 [==============================] - 1s 24ms/step - loss: 3.7084e-06 - accuracy: 1.0000 - val_loss: 0.4478 - val_accuracy: 0.9389
# Plot per-epoch accuracy (top) and cross-entropy loss (bottom) for the
# model trained WITHOUT augmentation.
hist_dict = history_non_augment.history
acc, val_acc = hist_dict['accuracy'], hist_dict['val_accuracy']
loss, val_loss = hist_dict['loss'], hist_dict['val_loss']

fig, (ax_acc, ax_loss) = plt.subplots(2, 1, figsize=(8, 8))

ax_acc.plot(acc, label='Training Accuracy')
ax_acc.plot(val_acc, label='Validation Accuracy')
ax_acc.legend(loc='lower right')
ax_acc.set_ylabel('Accuracy')
ax_acc.set_ylim([min(ax_acc.get_ylim()), 1.1])
ax_acc.set_title('Training and Validation Accuracy')

ax_loss.plot(loss, label='Training Loss')
ax_loss.plot(val_loss, label='Validation Loss')
ax_loss.legend(loc='upper right')
ax_loss.set_ylabel('Cross Entropy')
ax_loss.set_ylim([-0.1, 1.0])
ax_loss.set_title('Training and Validation Loss')
ax_loss.set_xlabel('epoch')

plt.show()
输出:
5.3 使用增强后的数据进行模型训练
# Same architecture and schedule as the baseline, but trained on the
# randomly augmented subset (augmentation is re-sampled every epoch).
model_with_augment = bulidModel()
history_with_augment = model_with_augment.fit(augmented_train_batches,
epochs=50,
validation_data=validation_batches)
输出:
Model: "sequential_3"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten_3 (Flatten) (None, 784) 0
_________________________________________________________________
dense_9 (Dense) (None, 4096) 3215360
_________________________________________________________________
dense_10 (Dense) (None, 4096) 16781312
_________________________________________________________________
dense_11 (Dense) (None, 10) 40970
=================================================================
Total params: 20,037,642
Trainable params: 20,037,642
Non-trainable params: 0
_________________________________________________________________
Epoch 1/50
32/32 [==============================] - 1s 27ms/step - loss: 2.3455 - accuracy: 0.3154 - val_loss: 1.1609 - val_accuracy: 0.6739
Epoch 2/50
32/32 [==============================] - 1s 23ms/step - loss: 1.3705 - accuracy: 0.5386 - val_loss: 0.7378 - val_accuracy: 0.7757
Epoch 3/50
32/32 [==============================] - 1s 23ms/step - loss: 0.9722 - accuracy: 0.6851 - val_loss: 0.4942 - val_accuracy: 0.8571
Epoch 4/50
32/32 [==============================] - 1s 23ms/step - loss: 0.7568 - accuracy: 0.7441 - val_loss: 0.3969 - val_accuracy: 0.8749
Epoch 5/50
32/32 [==============================] - 1s 23ms/step - loss: 0.6670 - accuracy: 0.7891 - val_loss: 0.3878 - val_accuracy: 0.8604
Epoch 6/50
32/32 [==============================] - 1s 23ms/step - loss: 0.6065 - accuracy: 0.8022 - val_loss: 0.3017 - val_accuracy: 0.9055
Epoch 7/50
32/32 [==============================] - 1s 24ms/step - loss: 0.5568 - accuracy: 0.8057 - val_loss: 0.2845 - val_accuracy: 0.9100
Epoch 8/50
32/32 [==============================] - 1s 24ms/step - loss: 0.5136 - accuracy: 0.8291 - val_loss: 0.2541 - val_accuracy: 0.9200
Epoch 9/50
32/32 [==============================] - 1s 25ms/step - loss: 0.4372 - accuracy: 0.8574 - val_loss: 0.2279 - val_accuracy: 0.9310
Epoch 10/50
32/32 [==============================] - 1s 24ms/step - loss: 0.4240 - accuracy: 0.8608 - val_loss: 0.2220 - val_accuracy: 0.9319
Epoch 11/50
32/32 [==============================] - 1s 25ms/step - loss: 0.4186 - accuracy: 0.8589 - val_loss: 0.2300 - val_accuracy: 0.9290
Epoch 12/50
32/32 [==============================] - 1s 25ms/step - loss: 0.3784 - accuracy: 0.8706 - val_loss: 0.2086 - val_accuracy: 0.9353
Epoch 13/50
32/32 [==============================] - 1s 24ms/step - loss: 0.3523 - accuracy: 0.8838 - val_loss: 0.2269 - val_accuracy: 0.9270
Epoch 14/50
32/32 [==============================] - 1s 24ms/step - loss: 0.3586 - accuracy: 0.8813 - val_loss: 0.2001 - val_accuracy: 0.9380
Epoch 15/50
32/32 [==============================] - 1s 24ms/step - loss: 0.3613 - accuracy: 0.8799 - val_loss: 0.2310 - val_accuracy: 0.9265
Epoch 16/50
32/32 [==============================] - 1s 24ms/step - loss: 0.3332 - accuracy: 0.8901 - val_loss: 0.2240 - val_accuracy: 0.9276
Epoch 17/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2640 - accuracy: 0.9121 - val_loss: 0.1797 - val_accuracy: 0.9417
Epoch 18/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2956 - accuracy: 0.9062 - val_loss: 0.2051 - val_accuracy: 0.9353
Epoch 19/50
32/32 [==============================] - 1s 23ms/step - loss: 0.3339 - accuracy: 0.8887 - val_loss: 0.1981 - val_accuracy: 0.9368
Epoch 20/50
32/32 [==============================] - 1s 23ms/step - loss: 0.3056 - accuracy: 0.9053 - val_loss: 0.2208 - val_accuracy: 0.9329
Epoch 21/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2541 - accuracy: 0.9126 - val_loss: 0.1700 - val_accuracy: 0.9453
Epoch 22/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2389 - accuracy: 0.9248 - val_loss: 0.1752 - val_accuracy: 0.9452
Epoch 23/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2478 - accuracy: 0.9180 - val_loss: 0.1723 - val_accuracy: 0.9461
Epoch 24/50
32/32 [==============================] - 1s 24ms/step - loss: 0.2275 - accuracy: 0.9292 - val_loss: 0.2045 - val_accuracy: 0.9379
Epoch 25/50
32/32 [==============================] - 1s 25ms/step - loss: 0.2495 - accuracy: 0.9185 - val_loss: 0.1836 - val_accuracy: 0.9428
Epoch 26/50
32/32 [==============================] - 1s 24ms/step - loss: 0.2581 - accuracy: 0.9199 - val_loss: 0.1581 - val_accuracy: 0.9507
Epoch 27/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2552 - accuracy: 0.9170 - val_loss: 0.1798 - val_accuracy: 0.9436
Epoch 28/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2268 - accuracy: 0.9336 - val_loss: 0.1763 - val_accuracy: 0.9438
Epoch 29/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2276 - accuracy: 0.9321 - val_loss: 0.1668 - val_accuracy: 0.9472
Epoch 30/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1971 - accuracy: 0.9395 - val_loss: 0.1537 - val_accuracy: 0.9503
Epoch 31/50
32/32 [==============================] - 1s 24ms/step - loss: 0.2141 - accuracy: 0.9341 - val_loss: 0.1570 - val_accuracy: 0.9506
Epoch 32/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1763 - accuracy: 0.9434 - val_loss: 0.1516 - val_accuracy: 0.9530
Epoch 33/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1866 - accuracy: 0.9355 - val_loss: 0.1556 - val_accuracy: 0.9536
Epoch 34/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1979 - accuracy: 0.9341 - val_loss: 0.1811 - val_accuracy: 0.9450
Epoch 35/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1973 - accuracy: 0.9390 - val_loss: 0.1626 - val_accuracy: 0.9515
Epoch 36/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2121 - accuracy: 0.9263 - val_loss: 0.1596 - val_accuracy: 0.9521
Epoch 37/50
32/32 [==============================] - 1s 22ms/step - loss: 0.1693 - accuracy: 0.9429 - val_loss: 0.1567 - val_accuracy: 0.9517
Epoch 38/50
32/32 [==============================] - 1s 25ms/step - loss: 0.2130 - accuracy: 0.9336 - val_loss: 0.1511 - val_accuracy: 0.9536
Epoch 39/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1854 - accuracy: 0.9409 - val_loss: 0.1744 - val_accuracy: 0.9451
Epoch 40/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1764 - accuracy: 0.9399 - val_loss: 0.1663 - val_accuracy: 0.9492
Epoch 41/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1800 - accuracy: 0.9438 - val_loss: 0.1676 - val_accuracy: 0.9481
Epoch 42/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2094 - accuracy: 0.9292 - val_loss: 0.1614 - val_accuracy: 0.9520
Epoch 43/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1776 - accuracy: 0.9370 - val_loss: 0.1511 - val_accuracy: 0.9566
Epoch 44/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1476 - accuracy: 0.9526 - val_loss: 0.1725 - val_accuracy: 0.9498
Epoch 45/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1537 - accuracy: 0.9482 - val_loss: 0.1611 - val_accuracy: 0.9509
Epoch 46/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1292 - accuracy: 0.9541 - val_loss: 0.1543 - val_accuracy: 0.9544
Epoch 47/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1605 - accuracy: 0.9458 - val_loss: 0.1591 - val_accuracy: 0.9529
Epoch 48/50
32/32 [==============================] - 1s 24ms/step - loss: 0.2012 - accuracy: 0.9390 - val_loss: 0.1858 - val_accuracy: 0.9453
Epoch 49/50
32/32 [==============================] - 1s 25ms/step - loss: 0.2423 - accuracy: 0.9185 - val_loss: 0.1525 - val_accuracy: 0.9553
Epoch 50/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1835 - accuracy: 0.9434 - val_loss: 0.1819 - val_accuracy: 0.9498
# Plot per-epoch accuracy (top) and cross-entropy loss (bottom) for the
# model trained WITH augmentation.
hist_dict = history_with_augment.history
acc, val_acc = hist_dict['accuracy'], hist_dict['val_accuracy']
loss, val_loss = hist_dict['loss'], hist_dict['val_loss']

fig, (ax_acc, ax_loss) = plt.subplots(2, 1, figsize=(8, 8))

ax_acc.plot(acc, label='Training Accuracy')
ax_acc.plot(val_acc, label='Validation Accuracy')
ax_acc.legend(loc='lower right')
ax_acc.set_ylabel('Accuracy')
ax_acc.set_ylim([min(ax_acc.get_ylim()), 1.1])
ax_acc.set_title('Training and Validation Accuracy')

ax_loss.plot(loss, label='Training Loss')
ax_loss.plot(val_loss, label='Validation Loss')
ax_loss.legend(loc='upper right')
ax_loss.set_ylabel('Cross Entropy')
ax_loss.set_ylim([-0.1, 1.0])
ax_loss.set_title('Training and Validation Loss')
ax_loss.set_xlabel('epoch')

plt.show()
输出:
6. 总结
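在第 50 个 epoch 结束时，用增强后的数据训练的模型验证集准确率为 94.98%，用未增强的数据训练的模型为 93.89%，使用数据增强的模型略优。

损失曲线上的差异更明显：
- 未增强的模型：训练损失降至约 4e-6，训练准确率为 100%；但验证损失从最低的约 0.27 回升到约 0.45，这是典型的过拟合。
- 使用数据增强的模型：在后期（约第 30–50 个 epoch），训练损失保持在约 0.13–0.24，验证损失稳定在约 0.15–0.19，泛化能力更好。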