优达学城(Udacity)
TensorFlow Free Course Lesson 5 练习题(花朵图像分类)
代码:
引入库和文件下载整理
import os
import numpy as np
import glob
import shutil
import matplotlib.pyplot as plt
import tensorflow.python as tf
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
_URL = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"

# Download the archive only if it is not already in the Keras cache; with
# extract=True it is unpacked into the cache directory, which is why base_dir
# is built from the archive's directory below.
zip_file = tf.keras.utils.get_file(origin=_URL,
                                   fname="flower_photos.tgz",
                                   extract=True)
base_dir = os.path.join(os.path.dirname(zip_file), 'flower_photos')
classes = ['roses', 'daisy', 'dandelion', 'sunflowers', 'tulips']
# e.g. C:\Users\Mry\.keras\datasets\flower_photos


def _move_into(paths, dest_dir):
    """Move every file in *paths* into *dest_dir*, creating the directory once.

    Each file is moved to an explicit destination file name with os.replace,
    which overwrites an existing file of the same name. shutil.move(src, dir)
    would instead raise "Destination path already exists" when the script is
    re-run and the archive has been unpacked again.
    """
    os.makedirs(dest_dir, exist_ok=True)
    for src in paths:
        os.replace(src, os.path.join(dest_dir, os.path.basename(src)))


for cl in classes:  # per-class count, then an 80/20 train/val split
    img_path = os.path.join(base_dir, cl)
    print(img_path)
    images = glob.glob(img_path + '/*.jpg')
    print("{}: {} Images".format(cl, len(images)))
    split_at = round(len(images) * 0.8)
    train, val = images[:split_at], images[split_at:]
    _move_into(train, os.path.join(base_dir, 'train', cl))  # training split
    _move_into(val, os.path.join(base_dir, 'val', cl))  # validation split

train_dir = os.path.join(base_dir, 'train')
val_dir = os.path.join(base_dir, 'val')
batch_size = 100  # images per batch
IMG_SHAPE = 150  # images are resized to IMG_SHAPE x IMG_SHAPE
数据扩充
-----------水平翻转
# Display helper: shows (augmented) images side by side in one row.
def plotImages(images_arr):
    """Draw the images in *images_arr* on a 1x5 grid and show the figure.

    Only the first five images are drawn; if fewer are given, the remaining
    axes are left without an image.
    """
    _, grid = plt.subplots(1, 5, figsize=(20, 20))
    panels = grid.flatten()
    for picture, panel in zip(images_arr, panels):
        panel.imshow(picture)
    plt.tight_layout()
    plt.show()
# Augmentation demo 1: random horizontal flip only (pixels scaled by 1/255).
image_gen = ImageDataGenerator(horizontal_flip=True, rescale=1. / 255)
train_data_gen = image_gen.flow_from_directory(
    directory=train_dir,
    target_size=(IMG_SHAPE, IMG_SHAPE),
    batch_size=batch_size,
    shuffle=True,
)
# Index batch 0 five times and keep its first image each time; the intent
# (per the plot below) is to see several random augmentations of that image.
augmented_images = [train_data_gen[0][0][0] for _ in range(5)]
plotImages(augmented_images)  # show how the first image was transformed
效果图展示:随机水平翻转(每次随机决定是否翻转)
---------随机旋转(旋转角度在 ±45° 范围内)
# Augmentation demo 2: random rotation (rotation_range=45 degrees).
image_gen = ImageDataGenerator(rotation_range=45, rescale=1. / 255)
_flow_args = dict(
    directory=train_dir,
    target_size=(IMG_SHAPE, IMG_SHAPE),
    batch_size=batch_size,
    shuffle=True,
)
train_data_gen = image_gen.flow_from_directory(**_flow_args)
# Five draws of the first image of batch 0, to compare the random rotations.
augmented_images = [train_data_gen[0][0][0] for _ in range(5)]
plotImages(augmented_images)  # show how the first image was transformed
效果图展示:
--------缩放
# Augmentation demo 3: random zoom (zoom_range=0.5).
image_gen = ImageDataGenerator(zoom_range=0.5, rescale=1. / 255)
_flow_args = dict(
    directory=train_dir,
    target_size=(IMG_SHAPE, IMG_SHAPE),
    batch_size=batch_size,
    shuffle=True,
)
train_data_gen = image_gen.flow_from_directory(**_flow_args)
# Five draws of the first image of batch 0, to compare the random zooms.
augmented_images = [train_data_gen[0][0][0] for _ in range(5)]
plotImages(augmented_images)  # show how the first image was transformed
效果展示:
------各种随机
# Final training generator: every augmentation shown above, combined.
_train_augmentation = dict(
    rotation_range=45,        # random rotation
    zoom_range=0.5,           # random zoom
    horizontal_flip=True,     # random horizontal flip
    width_shift_range=0.15,   # random horizontal shift
    height_shift_range=0.15,  # random vertical shift
)
image_gen_train = ImageDataGenerator(rescale=1. / 255, **_train_augmentation)
# class_mode='sparse' yields integer labels, matching the
# sparse_categorical_crossentropy loss used when compiling the model.
train_data_gen = image_gen_train.flow_from_directory(
    directory=train_dir,
    target_size=(IMG_SHAPE, IMG_SHAPE),
    batch_size=batch_size,
    shuffle=True,
    class_mode='sparse',
)
augmented_images = [train_data_gen[0][0][0] for _ in range(5)]
plotImages(augmented_images)  # show how the first image was transformed
效果展示:
------处理验证集(仅缩放,不做数据扩充)
# Validation generator: images are only rescaled (no random augmentation),
# with integer ('sparse') labels like the training generator.
image_gen_val = ImageDataGenerator(rescale=1. / 255)
val_data_gen = image_gen_val.flow_from_directory(
    directory=val_dir,
    target_size=(IMG_SHAPE, IMG_SHAPE),
    batch_size=batch_size,
    class_mode='sparse',
)
模型构建和训练
# CNN: three conv + max-pool stages, then a dense head with dropout that
# outputs softmax probabilities over the 5 flower classes.
model = Sequential([
    # Stage 1: 16 filters, 3x3 kernel; input is an IMG_SHAPE x IMG_SHAPE RGB image.
    Conv2D(16, 3, padding='same', activation='relu',
           input_shape=(IMG_SHAPE, IMG_SHAPE, 3)),
    MaxPooling2D(pool_size=(2, 2)),
    # Stage 2: 32 filters, 3x3 kernel.
    Conv2D(32, 3, padding='same', activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    # Stage 3: 64 filters. NOTE(review): kernel size is 2 here while the
    # other stages use 3 -- confirm this is intentional.
    Conv2D(64, 2, padding='same', activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    # Fully connected head.
    Flatten(),
    # Dropout randomly zeroes 20% of its inputs during training to reduce
    # overfitting; it does not change the size of its output.
    Dropout(0.2),
    Dense(512, activation='relu'),
    Dropout(0.2),
    # Output layer: one softmax probability per class.
    Dense(5, activation='softmax'),
])

# The generators use class_mode='sparse' (integer labels), hence the sparse
# variant of categorical cross-entropy.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
# Train the model.
# NOTE(review): the recorded run data below has 80 values per metric, so that
# run presumably used epochs = 80 rather than 20.
epochs = 20
# Steps needed to see every sample once per epoch (the last batch may be partial).
train_steps = int(np.ceil(train_data_gen.n / float(batch_size)))
val_steps = int(np.ceil(val_data_gen.n / float(batch_size)))
history = model.fit_generator(
    train_data_gen,
    steps_per_epoch=train_steps,
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=val_steps,
)
绘制训练集与验证集的准确率和损失曲线
# Plot training vs. validation accuracy and loss per epoch.
print(history.history.keys())
_hist = history.history
# Depending on the Keras/TensorFlow version, the 'accuracy' metric is
# recorded as 'acc'/'val_acc' (older Keras) or 'accuracy'/'val_accuracy'
# (TF 2.x). Accept either so the plot does not fail with a KeyError.
acc = _hist['acc'] if 'acc' in _hist else _hist['accuracy']
val_acc = _hist['val_acc'] if 'val_acc' in _hist else _hist['val_accuracy']
loss = _hist['loss']
val_loss = _hist['val_loss']

# Take the x-axis from the recorded history rather than from `epochs`, so it
# always matches the length of the plotted series.
epochs_range = range(len(acc))

plt.figure(figsize=(8, 8))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()
运行数据(下列每组记录各有 80 个数值,即每个 epoch 一个;说明这次运行时 epochs 实际设为 80,而不是上文代码中的 20)
————————训练集准确率————————
[0.3488927, 0.4834753, 0.5614991, 0.58364564, 0.62521297,0.6289608, 0.6453152, 0.6545145, 0.6678024, 0.6684838,
0.68245316, 0.68688244, 0.68892676, 0.69914824, 0.7056218,0.69574106, 0.711414, 0.7144804, 0.71243614, 0.72231686,
0.7168654, 0.74241906, 0.7301533, 0.7512777, 0.7407155,0.74923337, 0.7649063, 0.7601363, 0.7478705, 0.750937,
0.750937, 0.75502557, 0.758092, 0.7560477, 0.76626915,0.7706985, 0.7798978, 0.76149917, 0.77206135, 0.78160137,
0.79011923, 0.79216355, 0.7744463, 0.79863715, 0.79693353,0.7826235, 0.80579215, 0.8068143, 0.80613285, 0.79693353,
0.8081772, 0.8146508, 0.8153322, 0.80545145, 0.8139693,0.81873935, 0.81908005, 0.81601363, 0.82487226, 0.82555366,
0.81873935, 0.8214651, 0.81873935, 0.839523, 0.8265758,0.8473595,0.8333901, 0.83918226, 0.83645654, 0.83577514,
0.8367973, 0.85621804, 0.8531516, 0.84463376, 0.84531516,0.8473595, 0.8528109, 0.8572402, 0.84293014, 0.8606474]
————————验证集准确率(val_acc)————————
[0.43265307, 0.5251701, 0.5891156, 0.6258503, 0.60408163, 0.62721086, 0.6244898, 0.662585, 0.68435377, 0.7047619,
0.66394556, 0.7047619, 0.6897959, 0.7020408, 0.6979592, 0.6993197, 0.72789115, 0.7306122, 0.71428573, 0.6952381,
0.722449, 0.74285716, 0.7210884, 0.72380954, 0.7360544, 0.74285716, 0.7360544, 0.7578231, 0.7306122, 0.75238097,
0.7360544, 0.7537415, 0.7482993, 0.7442177, 0.7537415, 0.7659864, 0.7578231, 0.75102043, 0.7537415, 0.74965984,
0.73741496, 0.7346939, 0.74965984, 0.7659864, 0.7442177, 0.7659864, 0.7578231, 0.7768707, 0.7632653, 0.76462585,
0.7632653, 0.77414966, 0.76870745, 0.74965984, 0.77823126, 0.7809524, 0.76462585, 0.75918365, 0.76870745, 0.7673469,
0.7714286, 0.7768707, 0.7823129, 0.78367347, 0.77959186, 0.77823126, 0.78639454, 0.7931973, 0.7768707, 0.7768707,
0.7945578, 0.77414966, 0.785034, 0.78911567, 0.77959186, 0.78775513, 0.7768707, 0.785034, 0.79727894, 0.7823129]
————————训练集loss————————
[1.5823648457226598, 1.1944781958142985, 1.1012410254454084, 1.03964230484207, 0.9662862563255089,
0.9517874297332114, 0.9106217200345846, 0.8875170186593545, 0.8618406035018739, 0.8548744114740947,
0.8273825055492797, 0.8046525978757738, 0.8108818937200907, 0.7936269302790575, 0.7615727042014595,
0.7822818707363772, 0.7454067135588274, 0.7267284239983437, 0.7281319564002133, 0.7204672338812274,
0.7229581823543958, 0.6699235792663678, 0.6970297495185618, 0.6569172825784861, 0.6764492847444251,
0.6521440903311897, 0.6310099546101064, 0.6226096797030057, 0.6459155098949217, 0.658994189417545,
0.6519654699850326, 0.622800934152262, 0.6306515725954631, 0.6136238323159681, 0.5983666666930945,
0.5873160181500436, 0.5857151501020337, 0.5894017121048558, 0.5703074351283111, 0.5615118014954628,
0.555702697834733, 0.5547649235904115, 0.5691021034014895, 0.5436795555956319, 0.5172360462772785,
0.5476756975232437, 0.5215101944730066, 0.5123172948165484, 0.5086335099129295, 0.5132814260418614,
0.5009467172257133, 0.478149976600374, 0.4940059999387959, 0.5166569145006054, 0.49103465366607224,
0.4847617827647792, 0.4673643915880681, 0.4921251035833277, 0.46098112313670236, 0.4568277422268915,
0.463003586260586, 0.4722275963430713, 0.45487462641636395, 0.42985527443317173, 0.4522078438349436,
0.4100336758911711, 0.4296528099951996, 0.43007276548640694, 0.4242795927154957, 0.44047689625717307,
0.4186704042518362, 0.3811385627499634, 0.40322202880703406, 0.41596289042516704, 0.39976311547386584,
0.3942081443901972, 0.39740706089204986, 0.3763422839584708, 0.4061678959726272, 0.38046372139880363]
————————验证集 loss(val_loss)————————
[1.222036674618721, 1.0966251567006111, 1.0576136335730553, 0.9514805302023888, 0.9740455150604248,
0.8776477500796318, 0.9441980272531509, 0.8600848317146301, 0.7724986374378204, 0.7943335473537445,
0.854674756526947, 0.7276982888579369, 0.7962107062339783, 0.7357299774885178, 0.7221283726394176,
0.7294800542294979, 0.6808925196528435, 0.690061941742897, 0.7871735170483589, 0.7951554134488106,
0.6850615888834, 0.6963930949568748, 0.7047825902700424, 0.7068115174770355, 0.677982360124588,
0.6587852314114571, 0.6862229183316231, 0.6841201297938824, 0.729791484773159, 0.6491366103291512,
0.6675430610775948, 0.6344344541430473, 0.6665278077125549, 0.6845718398690224, 0.629774022847414,
0.6481835767626762, 0.6688399091362953, 0.677779957652092, 0.6494889855384827, 0.6832970231771469,
0.6830611228942871, 0.7100113481283188, 0.6574587896466255, 0.632441807538271, 0.6827348843216896,
0.6315791495144367, 0.6429122425615788, 0.62912168353796, 0.6635183691978455, 0.6787677183747292,
0.6551337093114853, 0.6567187458276749, 0.663389440625906, 0.7014291733503342, 0.6835674792528152,
0.6435204781591892, 0.6926406696438789, 0.7112061604857445, 0.6817730404436588, 0.6317573599517345,
0.6463879160583019, 0.6991499103605747, 0.6600395403802395, 0.6666725762188435, 0.6460988968610764,
0.6549481004476547, 0.6569248586893082, 0.6410825401544571, 0.6676031276583672, 0.7397986575961113,
0.6627241857349873, 0.6820458173751831, 0.6711707413196564, 0.6281307749450207, 0.7054491937160492,
0.6480866856873035, 0.6436318531632423, 0.7016018405556679, 0.6524763368070126, 0.6591740623116493]
'''