Xception迁移学习:玉米叶片病害识别分类
- 数据集:来自网上公开的PlantVillage数据集中的玉米叶片部分。
- 运行环境:Tensorflow深度学习开源框架,选用Python 3.6.12作为编程语言。
本代码是自己查阅了很多博客代码最后根绝自己要用的数据集综合而成的,由于过于久远,不记得参考了哪些博客,这里就不放链接了。记录下来,便于自己以后查阅。也是刚入门的小白,欢迎大佬指教!
代码如下
1. 导入
import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import tensorflow.keras.preprocessing.image as image
import os as os
from tensorflow.keras.applications import Xception
from tensorflow.keras.layers import Dense,Flatten,GlobalAveragePooling2D,Dropout
from tensorflow.keras.models import Model,load_model
from tensorflow.keras.optimizers import SGD
2. 设置参数和路径
IMG_SIZE:输入图片的尺寸;
batch_size:每次读取图片的数量;
EPOCHS:训练轮次;
train_path:训练集路径;val_path:验证集路径。
IMG_SIZE = 150
batch_size = 16
EPOCHS=100
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
train_path = 'D:/tmp/New Maize Data set/Train_maize'
val_path='D:/tmp/New Maize Data set/Vali_maize'
3. 数据增强
由于电脑的配置低,带不动很多图片,所以只选取了每种病害图片的几百张作为训练集,故需要数据增强操作,提高分类准确率。
使用keras提供的图像生成器ImageDataGenerator类来实现数据增强。主要做法是每次取一个批次即batch_size大小的样本数据提供给模型,同时对每批样本进行归一化、随机旋转40°、随机水平和上下位置平移、随机错切变换角度、随机缩放比例、随机将一半图像水平翻转等操作。这样每一轮训练时输入的样本批次就不会完全相同,可以增强模型的泛化能力。
数据增强后的结果如图:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_gen = ImageDataGenerator(
rescale=1 / 255,
rotation_range=40, # 角度值,0-180.表示图像随机旋转的角度范围
width_shift_range=0.2, # 平移比例,下同
height_shift_range=0.2,
shear_range=0.2, # 随机错切变换角度
zoom_range=0.2, # 随即缩放比例
horizontal_flip=True, # 随机将一半图像水平翻转
validation_split=0.2,
fill_mode='nearest' # 填充新创建像素的方法
)
train_generator = train_gen.flow_from_directory(
directory=train_path,
shuffle = True,
batch_size = batch_size,
class_mode = 'categorical',
target_size = IMG_SHAPE[:-1],
color_mode='rgb',
#classes =classes,
#subset='training'
)
validation_generator = train_gen.flow_from_directory(
directory=val_path,
shuffle = True,
batch_size = batch_size,
class_mode = 'categorical',
target_size =IMG_SHAPE[:-1],
color_mode='rgb',
#classes =classes,
#subset='validation'
)
4. 构建模型
这里所用模型直接调用keras中的Xception模型
#构建模型
model = tf.keras.Sequential([tf.keras.applications.Xception(input_shape=(150,150,3),weights='imagenet',include_top=False),
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(4,activation='softmax')])
设置迁移学习冻结模型的层数:冻结部分网络层,即只训练其中的一部分网络层。
for i, layer in enumerate(model.layers[0].layers):
if i > 85:
layer.trainable = True
else:
layer.trainable = False
5. 编译模型
#编译模型
model.compile(optimizer='adam',
loss = 'categorical_crossentropy',
metrics=['accuracy'])
6. 打印模型
model.summary()
模型打印结果可以看到可训练的参数数量
7. 训练模型
history=model.fit_generator(train_generator,
steps_per_epoch=max(1, train_generator.n//batch_size),
validation_data=validation_generator,
validation_steps=max(1, validation_generator.n//batch_size),
epochs =100,
#initial_epoch=0,
#callbacks=[checkpoint]
)
8. 保存模型
将模型保存为.h5文件
model.save('model/Xception_2_85_model.h5')
9. 绘制损失值曲线和准确率曲线
# 记录准确率和损失值
history_dict = history.history
train_loss = history_dict["loss"]
train_accuracy = history_dict["acc"]
val_loss = history_dict["val_loss"]
val_accuracy = history_dict["val_acc"]
# 绘制损失值曲线
plt.figure()
plt.title('InceptionV3-1')
plt.plot(range(EPOCHS),train_loss,c='k' ,ls='--',label='train_loss')
plt.plot(range(EPOCHS),val_loss,'k' ,label='val_loss' )
plt.legend()
plt.xlabel('epochs')
plt.ylabel('loss')
import matplotlib as mpl
#中文字体设置
mpl.rcParams["font.family"] = "SimHei"
mpl.rcParams["axes.unicode_minus"] = False
mpl.rcParams["font.style"] = "normal"
mpl.rcParams["font.size"] = 10
# 绘制准确率曲线
plt.figure()
#plt.title('InceptionV3-1')
plt.plot(range(EPOCHS), train_accuracy,ls='--', c="k",label="训练集准确率")
plt.plot(range(EPOCHS), val_accuracy,c="k",label="验证集准确率")
plt.ylim(0.5,1)
plt.legend(loc='lower right')
plt.xlabel("训练轮次")
plt.ylabel("准确率")
plt.show()
10. 测试模型
测试结果可以输出一个混淆矩阵,查看每种病害类别的准确率。
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import load_model
import datetime
from tensorflow.keras.callbacks import TensorBoard
from keras.backend.tensorflow_backend import set_session
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import itertools
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
set_session(sess)
keras.backend.clear_session() #清理session
#test image directory
dst_path = 'D:/tmp/New Maize Data set/Test_maize'
#model path
model_file ='C:/Users/name/model/Xception_2_85_model.h5'
batch_size = 8
def plot_confusion_matrix(cm,
target_names,
title='Confusion Matrix',
cmap=plt.cm.Greens, # 设置混淆矩阵的颜色主题
normalize=True):
accuracy = np.trace(cm) / float(np.sum(cm))
misclass = 1 - accuracy
if cmap is None:
cmap = plt.get_cmap('Blues')
plt.figure()
plt.imshow(cm, interpolation='nearest', cmap=cmap)
# plt.title(title)
plt.title(title+'\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
plt.colorbar()
if target_names is not None:
tick_marks = np.arange(len(target_names))
plt.xticks(tick_marks, target_names, rotation=45)
plt.yticks(tick_marks, target_names)
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
thresh = cm.max() / 1.5 if normalize else cm.max() / 2
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
if normalize:
plt.text(j, i, "{:0.4f}".format(cm[i, j]),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
else:
plt.text(j, i, "{:,}".format(cm[i, j]),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
plt.ylabel('True label')
plt.xlabel('Predicted label')
# load model
model = load_model(model_file)
# generator image
test_datagen = ImageDataGenerator(rescale=1. / 255)
test_generator = test_datagen.flow_from_directory(
dst_path,
target_size=(150, 150),
batch_size=batch_size,
shuffle=False
)
labels = test_generator.class_indices #查看类别的label
#labels = ['blight', 'cercos', 'healthy','rust']
#然后直接用predice_geneorator 可以进行预测
test_generator.reset()
pred = model.predict_generator(test_generator, verbose=1)
# 输出每个图像的预测类别
predicted_class_indices = np.argmax(pred, axis=1)
#测试集的真实类别
true_label= test_generator.classes
#简单画出混淆矩阵
import pandas as pd
table=pd.crosstab(true_label,predicted_class_indices,colnames=['predict'],rownames=['label'])
print(table)
#图片化显示混淆矩阵
conf_mat = confusion_matrix(y_true=true_label,y_pred=predicted_class_indices)
plt.figure()
plot_confusion_matrix(conf_mat, normalize=False, target_names=labels, title='Confusion Matrix')
测试结果如下:可以看出每种类别的识别率都很高