1 AlexNet简介
AlexNet
在【ImageNet Classification with Deep Convolutional Neural Networks
】论文中提出的,并在ILSVRC-2012
获得第一名。AlexNet的特点:
- 相比较
LetNet-5
,更深的网络结构。 - 使用层叠的卷积层,即卷积层+卷积层+池化层来提取图像的特征。
- 使用
Dropout
抑制过拟合。 - 使用数据增强
Data Augmentation
抑制过拟合。 - 使用
Relu
替换之前的sigmoid
的作为激活函数。 - 多
GPU
训练。
整个网络采用的ReLu
非线性激活函数,可以参考深度学习之激活函数。论文给出下面的网络架构。
但是并不是很详细,下面的模型架构图讲的更清晰,下面我会一一进行介绍
总共有8
层,5
层卷积层+3
层全连接层,下面一张图来源网络,详细介绍每层的具体设置和参数。本文会根据这个模型的架构图来实现(稍微做调整)。
2 数据处理
数据集是由14
种花组成的,分别为康乃馨,杜鹃花,桂花,桃花,梅花,洛神花,牡丹,牵牛花,玫瑰,茉莉花,荷花,菊花,蒲公英,风信子
,数据集中图片大小尺寸不一致,我们需要统一处理为相同的尺寸[224,224]
。采用 Keras
中使用 ImageDataGenerator
进行图像增强处理,基本使用详见参考第十五章 用图像增强改善模型性能。
2.1 数据集基本设置
IMG_W = 224 # 定义裁剪的图片宽度
IMG_H = 224 # 定义裁剪的图片高度
CLASS = 14 # 图片的分类数
EPOCHS = 5 # 迭代周期
BATCH_SIZE = 64 # 批次大小
TRAIN_PATH = './data/data/train' # 训练集存放路径
TEST_PATH = './data/data/test' # 测试集存放路径
SAVE_PATH = './data/flower_selector' # 模型保存路径
LEARNING_RATE = 1e-4 # 学习率
DROPOUT_RATE = 0 # 抗拟合,不工作的神经网络百分比
2.2 ImageDataGenerator设置
train_datagen = ImageDataGenerator(
rotation_range=40, # 随机旋转度数
width_shift_range=0.2, # 随机水平平移
height_shift_range=0.2, # 随机竖直平移
rescale=1 / 255, # 数据归一化
shear_range=20, # 随机错切变换
zoom_range=0.2, # 随机放大
horizontal_flip=True, # 水平翻转
fill_mode='nearest', # 填充方式
)
2.3 归一化处理
test_datagen = ImageDataGenerator(
rescale=1 / 255, # 数据归一化
)
2.4 训练集
train_generator = train_datagen.flow_from_directory( # 设置训练集迭代器
TRAIN_PATH, # 训练集存放路径
target_size=(IMG_W, IMG_H), # 训练集图片尺寸
batch_size=BATCH_SIZE # 训练集批次
)
2.5 测试集
test_generator = test_datagen.flow_from_directory( # 设置测试集迭代器
TEST_PATH, # 测试集存放路径
target_size=(IMG_W, IMG_H), # 测试集图片尺寸
batch_size=BATCH_SIZE, # 测试集批次
)
将预测结果转成标签名称
num_to_char = dict((test_generator.class_indices.get(index), index) for index in test_generator.class_indices)
2.6 数据集预览
for index in range(9):
x, y = test_generator[index]
tile = num_to_char[tensorflow.argmax(y[index]).numpy()]
plt.subplot(330 + 1 + index)
plt.title(tile)
plt.imshow(x[:][index])
plt.show()
3 构建模型
class AlexNet:
def __init__(self):
self.model = Sequential()
# 模型构建
def train(self):
self.model.add(Conv2D(input_shape=(224, 224, 3),
kernel_size=(11, 11),
strides=4,
filters=96,
activation=tf.keras.activations.relu))
self.model.add(MaxPooling2D(pool_size=(3, 3), strides=2))
self.model.add(Conv2D(kernel_size=(5, 5),
strides=1,
filters=48,
activation=tf.keras.activations.relu))
self.model.add(MaxPooling2D(pool_size=(3, 3), strides=2))
self.model.add(Conv2D(kernel_size=(3, 3),
strides=1,
filters=128,
activation=tf.keras.activations.relu))
self.model.add(Conv2D(kernel_size=(3, 3),
strides=1,
filters=192,
activation=tf.keras.activations.relu))
self.model.add(Conv2D(kernel_size=(3, 3),
strides=1,
filters=192,
activation=tf.keras.activations.relu))
self.model.add(MaxPooling2D(pool_size=(3, 3), strides=2))
self.model.add(Flatten())
self.model.add(Dense(128, activation=tf.keras.activations.relu))
self.model.add(Dropout(0.5))
self.model.add(Dense(64, activation=tf.keras.activations.relu))
self.model.add(Dropout(0.5))
self.model.add(Dense(14, activation=tf.keras.activations.softmax))
print('model framework', self.model.summary())
self.model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics=['accuracy'])
try:
self.model.load_weights('{}.h5'.format(SAVE_PATH)) # 尝试读取训练好的模型,再次训练
print('model upload,start training!')
except:
print('not find model,start training') # 如果没有训练过的模型,则从头开始训练
checkpoint = ModelCheckpoint('{}.h5'.format(SAVE_PATH), monitor='accuracy', verbose=1, save_best_only=True, model='max')
callbacks_list = [checkpoint]
self.history = self.model.fit(
train_generator, # 训练集迭代器
steps_per_epoch=len(train_generator), # 每个周期需要迭代多少步(图片总量/批次大小=11200/64=175)
epochs=EPOCHS, # 迭代周期
verbose=1,
validation_data=test_generator, # 测试集迭代器
validation_steps=len(test_generator), # 测试集迭代多少步
callbacks=callbacks_list
)
def predict(self, x):
self.model.predict(x)
# 绘制训练历史记录
def plot_history(self):
plt.plot(self.history.history['loss'], label='loss')
plt.plot(self.history.history['val_loss'], label='val_loss')
plt.plot(self.history.history['accuracy'], label='accuracy')
plt.tit le('history')
plt.legend()
plt.show()
4 训练记录
由于自己电脑没有GPU
训练的太慢,只能断断续续的进行训练,这是其中一段训练历史记录。学了下Gooogle Colab,准备用它来训练。
5 预测结果
欢迎点赞收藏,下篇文章将模型移植到Android端,实现拍照识别鲜花!!!!!!