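Dataset splitting:
Read the source data folder and generate a new folder split into three subfolders: train, val and test.
param train_scale: proportion of the training set
param val_scale: proportion of the validation set
param test_scale: proportion of the test set
A train:val:test ratio of 8:1:1 is common, but you can choose a different split for your own dataset.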
import os


def data_set_split(src_data_folder, target_data_folder, train_scale=0.8, val_scale=0.2, test_scale=0.0):
    # Each subfolder of the source folder is one class
    class_names = os.listdir(src_data_folder)
    # Create train/val/test folders in the target directory, each with one subfolder per class
    split_names = ['train', 'val', 'test']
    for split_name in split_names:
        split_path = os.path.join(target_data_folder, split_name)
        # exist_ok=True skips folders that already exist (replaces the isdir/mkdir checks)
        os.makedirs(split_path, exist_ok=True)
        for class_name in class_names:
            class_split_path = os.path.join(split_path, class_name)
            os.makedirs(class_split_path, exist_ok=True)
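The function above only creates the folder structure, while the log below also shows each class's images being copied into the three splits. That copy step is not shown in the post, so the following is a minimal sketch under assumptions: the hypothetical helper `split_class_folder` shuffles one class's file list with a fixed seed and copies the files with `shutil.copy2`. It cuts at `int(n * scale)`, so its counts can differ by one image from the log (800 rather than 801 training images for a 1000-image class).

```python
import os
import random
import shutil


def split_class_folder(src_class_dir, target_root, class_name,
                       train_scale=0.8, val_scale=0.2, seed=0):
    """Copy one class's images into train/val/test subfolders (hypothetical helper).

    Whatever remains after the train and val cut-offs goes to test.
    Returns the number of files copied into each split.
    """
    files = sorted(os.listdir(src_class_dir))
    random.Random(seed).shuffle(files)  # reproducible shuffle
    n = len(files)
    train_stop = int(n * train_scale)
    val_stop = train_stop + int(n * val_scale)
    splits = {
        'train': files[:train_stop],
        'val': files[train_stop:val_stop],
        'test': files[val_stop:],
    }
    counts = {}
    for split_name, split_files in splits.items():
        dst_dir = os.path.join(target_root, split_name, class_name)
        os.makedirs(dst_dir, exist_ok=True)
        for fname in split_files:
            shutil.copy2(os.path.join(src_class_dir, fname), os.path.join(dst_dir, fname))
        counts[split_name] = len(split_files)
    print(f"{class_name}: train {counts['train']}, val {counts['val']}, test {counts['test']}")
    return counts
```

Call it once per class subfolder of the source folder, passing the same target folder each time.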
The output of a run is shown below.
Save the output of every run so you can refer to it later; it is a good habit for any programmer.
Starting dataset split
*********************************Cherry___healthy*************************************
Class Cherry___healthy split with ratio 0.8:0.2:0.0, 1000 images in total
Train set D:/pycharm/ku/模型训练/new_data/Plant Disease and Pest Dataset\train\Cherry___healthy: 801 images
Val set D:/pycharm/ku/模型训练/new_data/Plant Disease and Pest Dataset\val\Cherry___healthy: 199 images
Test set D:/pycharm/ku/模型训练/new_data/Plant Disease and Pest Dataset\test\Cherry___healthy: 0 images
*********************************Cherry___Powdery_mildew*************************************
Class Cherry___Powdery_mildew split with ratio 0.8:0.2:0.0, 1052 images in total
Train set D:/pycharm/ku/模型训练/new_data/Plant Disease and Pest Dataset\train\Cherry___Powdery_mildew: 842 images
Val set D:/pycharm/ku/模型训练/new_data/Plant Disease and Pest Dataset\val\Cherry___Powdery_mildew: 210 images
Test set D:/pycharm/ku/模型训练/new_data/Plant Disease and Pest Dataset\test\Cherry___Powdery_mildew: 0 images
*********************************Corn___Common_rust*************************************
Class Corn___Common_rust split with ratio 0.8:0.2:0.0, 1000 images in total
Train set D:/pycharm/ku/模型训练/new_data/Plant Disease and Pest Dataset\train\Corn___Common_rust: 801 images
Val set D:/pycharm/ku/模型训练/new_data/Plant Disease and Pest Dataset\val\Corn___Common_rust: 199 images
Test set D:/pycharm/ku/模型训练/new_data/Plant Disease and Pest Dataset\test\Corn___Common_rust: 0 images
*********************************Corn___healthy*************************************
Class Corn___healthy split with ratio 0.8:0.2:0.0, 1162 images in total
Train set D:/pycharm/ku/模型训练/new_data/Plant Disease and Pest Dataset\train\Corn___healthy: 930 images
Val set D:/pycharm/ku/模型训练/new_data/Plant Disease and Pest Dataset\val\Corn___healthy: 232 images
Test set D:/pycharm/ku/模型训练/new_data/Plant Disease and Pest Dataset\test\Corn___healthy: 0 images
*********************************Grape___Black_rot*************************************
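To keep each run's output as advised above, one option (an assumption, not the author's setup) is Python's standard `logging` module writing every message to both the console and a file. The helper `setup_run_logger` and the log file name are hypothetical.

```python
import logging
import sys


def setup_run_logger(log_path, name="run"):
    """Return a logger that prints each message and appends it to log_path."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # do not also pass messages to the root logger
    # Remove handlers left over from an earlier call so lines are not duplicated
    for old in list(logger.handlers):
        logger.removeHandler(old)
        old.close()
    formatter = logging.Formatter("%(asctime)s %(message)s")
    for handler in (logging.StreamHandler(sys.stdout),
                    logging.FileHandler(log_path, encoding="utf-8")):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
```

Replacing `print(...)` with `logger.info(...)` then leaves a timestamped copy of every run on disk, e.g. `logger = setup_run_logger("split_run.log")`.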
CNN training:
Next, the split dataset is used to train the CNN model.
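The loading code is not shown in the post. A minimal sketch, assuming TensorFlow 2.x and the folder layout produced above: `tf.keras.utils.image_dataset_from_directory` reads one subfolder per class. `label_mode="categorical"` (one-hot labels, consistent with the `np.argmax(test_batch_labels, axis=1)` in the evaluation code) and `batch_size=16` are assumptions; 16 is consistent with the logs, where 14875 training images give 930 steps per epoch. The helper name `load_datasets` is hypothetical; older TF releases expose the same function as `tf.keras.preprocessing.image_dataset_from_directory`.

```python
import tensorflow as tf


def load_datasets(train_dir, val_dir, img_size=(224, 224), batch_size=16):
    """Load train/val image folders (one subfolder per class) as tf.data datasets."""
    train_ds = tf.keras.utils.image_dataset_from_directory(
        train_dir, label_mode="categorical", image_size=img_size,
        batch_size=batch_size, shuffle=True, seed=123)
    val_ds = tf.keras.utils.image_dataset_from_directory(
        val_dir, label_mode="categorical", image_size=img_size,
        batch_size=batch_size, shuffle=False)
    # Read class_names before further tf.data transformations drop the attribute
    class_names = train_ds.class_names
    # Prefetch overlaps data loading with training
    train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
    val_ds = val_ds.prefetch(tf.data.AUTOTUNE)
    return train_ds, val_ds, class_names
```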
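Build the model --> normalize the input --> repeated convolution and pooling layers --> flatten from 2D to 1D --> softmax activation giving the probability of each class --> output the model --> return the model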
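The pipeline above can be sketched in Keras as follows. The layer sizes are read off the model summary in the log below (two 3×3 convolutions with 32 and 64 filters, 2×2 max pooling, a 128-unit dense layer, 16 output classes), and the parameter count matches it: 23,909,456. The rescaling factor `1/255`, the ReLU activations and the compile settings (Adam with categorical cross-entropy, as the post states for MobileNet) are assumptions; `build_cnn` is a hypothetical name.

```python
import tensorflow as tf


def build_cnn(img_shape=(224, 224, 3), num_classes=16):
    model = tf.keras.Sequential([
        tf.keras.Input(shape=img_shape),
        # Normalization: map pixel values from [0, 255] to [0, 1]
        tf.keras.layers.Rescaling(1.0 / 255),
        # Convolution + pooling, repeated twice
        tf.keras.layers.Conv2D(32, (3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D(2, 2),
        # 2D feature maps -> 1D vector
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        # Softmax gives one probability per class
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
    return model
```

Almost all parameters sit in the first dense layer (186624 × 128 + 128 = 23,888,000), which is why this small CNN is ten times larger than the MobileNet model later on.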
The training log is as follows:
Found 3703 files belonging to 16 classes.
['Cherry___Powdery_mildew', 'Cherry___healthy', 'Corn___Common_rust', 'Corn___healthy', 'Grape___Black_rot', 'Grape___healthy', 'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight', 'Potato___healthy', 'Strawberry___Leaf_scorch', 'Strawberry___healthy', 'Tomato___Late_blight', 'Tomato___healthy']
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
rescaling (Rescaling)        (None, 224, 224, 3)       0
_________________________________________________________________
conv2d (Conv2D)              (None, 222, 222, 32)      896
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 111, 111, 32)      0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 109, 109, 64)      18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 54, 54, 64)        0
_________________________________________________________________
flatten (Flatten)            (None, 186624)            0
_________________________________________________________________
dense (Dense)                (None, 128)               23888000
_________________________________________________________________
dense_1 (Dense)              (None, 16)                2064
=================================================================
Total params: 23,909,456
Trainable params: 23,909,456
Non-trainable params: 0
_________________________________________________________________
Epoch 1/30
930/930 [==============================] - 352s 379ms/step - loss: 1.9201 - accuracy: 0.3683 - val_loss: 1.2696 - val_accuracy: 0.5406
Epoch 2/30
930/930 [==============================] - 351s 377ms/step - loss: 0.8038 - accuracy: 0.7368 - val_loss: 0.7117 - val_accuracy: 0.7710
Epoch 3/30
930/930 [==============================] - 350s 376ms/step - loss: 0.4860 - accuracy: 0.8429 - val_loss: 0.5012 - val_accuracy: 0.8339
Epoch 4/30
930/930 [==============================] - 350s 376ms/step - loss: 0.3434 - accuracy: 0.8916 - val_loss: 0.4698 - val_accuracy: 0.8445
Epoch 5/30
930/930 [==============================] - 348s 375ms/step - loss: 0.2449 - accuracy: 0.9205 - val_loss: 0.2892 - val_accuracy: 0.9060
Epoch 6/30
930/930 [==============================] - 349s 375ms/step - loss: 0.1781 - accuracy: 0.9449 - val_loss: 0.4551 - val_accuracy: 0.8677
Epoch 7/30
930/930 [==============================] - 349s 375ms/step - loss: 0.1165 - accuracy: 0.9642 - val_loss: 0.2212 - val_accuracy: 0.9306
Epoch 8/30
930/930 [==============================] - 1377s 1s/step - loss: 0.0915 - accuracy: 0.9701 - val_loss: 0.2250 - val_accuracy: 0.9317
Epoch 9/30
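Note: the second line of the output is the list of all class names in your dataset; save it separately, since it maps the model's output indices back to class names.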
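One way to save the class names, assuming they come from the `class_names` attribute of the dataset returned by `image_dataset_from_directory`, is a small JSON file. The helper names and the file name are hypothetical.

```python
import json


def save_class_names(class_names, path="class_names.json"):
    # Keep the list order: index i of the model output corresponds to class_names[i]
    with open(path, "w", encoding="utf-8") as f:
        json.dump(list(class_names), f, ensure_ascii=False, indent=2)


def load_class_names(path="class_names.json"):
    with open(path, encoding="utf-8") as f:
        return json.load(f)
```

At inference time, indexing the loaded list with the argmax of the model's output turns a prediction back into a label.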
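MobileNet training:
Likewise, load the split dataset and then build the MobileNet (MobileNetV2) model:
Freeze the backbone parameters.
Normalize the input.
Add the backbone model.
Apply global average pooling to the backbone output.
Map to the number of classes with a fully connected layer.
Train with the Adam optimizer and a cross-entropy loss.
(You can change the number of training epochs as needed.)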
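The steps above can be sketched as follows, assuming TensorFlow 2.x. The backbone is `tf.keras.applications.MobileNetV2` without its classification top (the log shows the `mobilenet_v2 ... no_top` weights being downloaded), and the head matches the summary: 1280 pooled features mapped to 16 classes, 20,496 trainable parameters. Rescaling to [-1, 1] (`1/127.5` with `offset=-1`, the input range MobileNetV2 expects) is an assumption about the original normalization, and `build_mobilenet` is a hypothetical name. `weights="imagenet"` downloads pretrained weights; `weights=None` builds the same architecture without downloading.

```python
import tensorflow as tf


def build_mobilenet(img_shape=(224, 224, 3), num_classes=16, weights="imagenet"):
    # Backbone: MobileNetV2 without its ImageNet classification head
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=img_shape, include_top=False, weights=weights)
    # Freeze the backbone so only the new head is trained
    base_model.trainable = False
    model = tf.keras.Sequential([
        tf.keras.Input(shape=img_shape),
        # Normalization: map [0, 255] to [-1, 1]
        tf.keras.layers.Rescaling(1.0 / 127.5, offset=-1),
        base_model,
        # Global average pooling: (7, 7, 1280) -> (1280,)
        tf.keras.layers.GlobalAveragePooling2D(),
        # Fully connected layer mapping to the number of classes
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])
    # Adam optimizer and cross-entropy loss on one-hot labels
    model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
    return model
```

Training then calls `model.fit(...)` on the training and validation datasets with, for example, `epochs=30` as in the log. Freezing the backbone is why only 20,496 of the 2,278,480 parameters are trained.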
Found 3703 files belonging to 16 classes.
['Cherry___Powdery_mildew', 'Cherry___healthy', 'Corn___Common_rust', 'Corn___healthy', 'Grape___Black_rot', 'Grape___healthy', 'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight', 'Potato___healthy', 'Strawberry___Leaf_scorch', 'Strawberry___healthy', 'Tomato___Late_blight', 'Tomato___healthy']
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
9412608/9406464 [==============================] - 35s 4us/step
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
rescaling (Rescaling)        (None, 224, 224, 3)       0
_________________________________________________________________
mobilenetv2_1.00_224 (Functi (None, 7, 7, 1280)        2257984
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0
_________________________________________________________________
dense (Dense)                (None, 16)                20496
=================================================================
Total params: 2,278,480
Trainable params: 20,496
Non-trainable params: 2,257,984
_________________________________________________________________
Epoch 1/30
930/930 [==============================] - 229s 246ms/step - loss: 0.2130 - accuracy: 0.9543 - val_loss: 0.0653 - val_accuracy: 0.9860
Epoch 2/30
930/930 [==============================] - 235s 252ms/step - loss: 0.0351 - accuracy: 0.9940 - val_loss: 0.0488 - val_accuracy: 0.9846
Epoch 3/30
930/930 [==============================] - 234s 252ms/step - loss: 0.0182 - accuracy: 0.9972 - val_loss: 0.0359 - val_accuracy: 0.9900
Epoch 4/30
930/930 [==============================] - 235s 252ms/step - loss: 0.0103 - accuracy: 0.9990 - val_loss: 0.0329 - val_accuracy: 0.9897
Epoch 5/30
930/930 [==============================] - 235s 253ms/step - loss: 0.0069 - accuracy: 0.9993 - val_loss: 0.0336 - val_accuracy: 0.9884
Epoch 6/30
930/930 [==============================] - 235s 253ms/step - loss: 0.0044 - accuracy: 0.9997 - val_loss: 0.0350 - val_accuracy: 0.9876
Epoch 7/30
930/930 [==============================] - 234s 252ms/step - loss: 0.0030 - accuracy: 0.9999 - val_loss: 0.0279 - val_accuracy: 0.9927
Epoch 8/30
930/930 [==============================] - 235s 252ms/step - loss: 0.0020 - accuracy: 1.0000 - val_loss: 0.0238 - val_accuracy: 0.9943
Epoch 9/30
930/930 [==============================] - 239s 257ms/step - loss: 0.0014 - accuracy: 1.0000 - val_loss: 0.0260 - val_accuracy: 0.9943
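The per-epoch numbers in these logs are easier to compare as curves. A minimal sketch, assuming matplotlib and the `history.history` dictionary returned by Keras `model.fit` (keys `accuracy`, `val_accuracy`, `loss` and `val_loss` when `metrics=["accuracy"]` is set); `plot_history` and the output file name are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # render to a file without a display
import matplotlib.pyplot as plt


def plot_history(history_dict, out_path="history.png"):
    epochs = range(1, len(history_dict["loss"]) + 1)
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(10, 4))
    # Left panel: training vs validation accuracy
    ax_acc.plot(epochs, history_dict["accuracy"], label="train")
    ax_acc.plot(epochs, history_dict["val_accuracy"], label="val")
    ax_acc.set_xlabel("epoch")
    ax_acc.set_ylabel("accuracy")
    ax_acc.legend()
    # Right panel: training vs validation loss
    ax_loss.plot(epochs, history_dict["loss"], label="train")
    ax_loss.plot(epochs, history_dict["val_loss"], label="val")
    ax_loss.set_xlabel("epoch")
    ax_loss.set_ylabel("loss")
    ax_loss.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=100)
    plt.close(fig)
    return out_path
```

Usage: `history = model.fit(...)` followed by `plot_history(history.history, "mobilenet_history.png")`.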
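After both models are trained, evaluate their accuracy.
Load the data: the training set from the train folder and the evaluation set (test_ds) from the held-out folder. Because the split above used test_scale=0.0, the test folder is empty; the 3703 evaluation images in the output below match the validation split.
Test the accuracy of the MobileNet model and of the CNN model.
Run inference with each model separately.
Collect the true and predicted labels: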
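import numpy as np


def collect_labels(model, test_ds):
    # Run inference batch by batch and collect true and predicted class indices
    test_real_labels = []
    test_pre_labels = []
    for test_batch_images, test_batch_labels in test_ds:
        test_batch_labels = test_batch_labels.numpy()
        test_batch_pres = model.predict(test_batch_images)
        # Labels are one-hot encoded, so argmax recovers the class index
        test_batch_labels_max = np.argmax(test_batch_labels, axis=1)
        test_batch_pres_max = np.argmax(test_batch_pres, axis=1)
        test_real_labels.extend(test_batch_labels_max)
        test_pre_labels.extend(test_batch_pres_max)
    return test_real_labels, test_pre_labels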
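For each model, the output below prints a confusion matrix of counts followed by the same matrix normalized by row (each row divided by the number of true samples of that class). The post does not show that code; a minimal NumPy sketch with the hypothetical name `confusion_matrices` could be:

```python
import numpy as np


def confusion_matrices(real_labels, pre_labels, num_classes):
    """Return (counts, row_normalized): rows are true classes, columns are predictions."""
    counts = np.zeros((num_classes, num_classes), dtype=float)
    for real, pre in zip(real_labels, pre_labels):
        counts[real, pre] += 1
    row_sums = counts.sum(axis=1, keepdims=True)
    # Divide each row by its class size; empty rows stay at 0 instead of dividing by 0
    normalized = np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums > 0)
    return counts, normalized
```

The overall accuracy is `np.trace(counts) / counts.sum()`: for the MobileNet count matrix in the output below, the diagonal sums to 3685 of 3703 images, about 0.9951, matching the printed test accuracy.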
Plot charts to present the results:
[Figure: CNN]
[Figure: MobileNet]
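One common way to chart per-class results, assumed here rather than taken from the post, is a heatmap of the row-normalized confusion matrix drawn with matplotlib. `plot_confusion_matrix` and its arguments are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # render to a file without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix(matrix, class_names, title, out_path):
    matrix = np.asarray(matrix)
    fig, ax = plt.subplots(figsize=(10, 9))
    im = ax.imshow(matrix, cmap="Blues")  # darker cells = larger fraction
    fig.colorbar(im, ax=ax)
    ticks = np.arange(len(class_names))
    ax.set_xticks(ticks)
    ax.set_xticklabels(class_names, rotation=90)
    ax.set_yticks(ticks)
    ax.set_yticklabels(class_names)
    ax.set_xlabel("predicted label")
    ax.set_ylabel("true label")
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(out_path, dpi=100)
    plt.close(fig)
    return out_path
```

Off-diagonal cells then show which classes are confused; in the CNN output below, for example, Pepper,_bell___Bacterial_spot has the lowest diagonal value (0.879).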
Output:
Found 3703 files belonging to 16 classes.
232/232 [==============================] - 45s 192ms/step - loss: 0.0277 - accuracy: 0.9951
Mobilenet test accuracy : 0.9951390624046326
[[208. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1.]
 [ 0. 198. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [ 0. 0. 199. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [ 0. 0. 0. 232. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [ 0. 0. 0. 0. 235. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [ 0. 0. 0. 0. 0. 199. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [ 0. 0. 0. 0. 0. 0. 299. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [ 0. 0. 0. 0. 0. 0. 0. 199. 0. 0. 0. 0. 0. 0. 0. 0.]
 [ 0. 0. 0. 0. 0. 0. 0. 0. 196. 3. 0. 0. 0. 0. 0. 0.]
 [ 0. 0. 0. 0. 0. 0. 0. 0. 3. 292. 0. 0. 0. 0. 0. 0.]
 [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 198. 0. 0. 0. 0. 1.]
 [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 197. 0. 1. 0. 0.]
 [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 219. 0. 2. 0.]
 [ 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 198. 0. 0.]
 [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 299. 1.]
 [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 317.]]
[[0.99047619 0. 0. 0. 0. 0. 0. 0.0047619 0. 0. 0. 0. 0. 0. 0. 0.0047619 ]
 [0. 0.99497487 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.00502513 0. 0. 0. 0. ]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]
 [0. 0. 0. 0. 0. 0. 0.99666667 0. 0. 0. 0. 0. 0. 0. 0.00333333 0. ]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. ]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.98492462 0.01507538 0. 0. 0. 0. 0. 0. ]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.01016949 0.98983051 0. 0. 0. 0. 0. 0. ]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.99497487 0. 0. 0. 0. 0.00502513]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.00502513 0. 0.98994975 0. 0.00502513 0. 0. ]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.99095023 0. 0.00904977 0. ]
 [0. 0. 0. 0. 0. 0.00502513 0. 0. 0. 0. 0. 0. 0. 0.99497487 0. 0. ]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.99666667 0.00333333]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.00314465 0. 0.99685535]]
Found 14875 files belonging to 16 classes.
Found 3703 files belonging to 16 classes.
232/232 [==============================] - 20s 86ms/step - loss: 0.2065 - accuracy: 0.9576
CNN test accuracy : 0.9576019644737244
[[200. 0. 0. 0. 0. 0. 2. 1. 0. 0. 0. 0. 0. 1. 6. 0.]
 [ 0. 189. 0. 0. 0. 2. 0. 1. 0. 1. 0. 6. 0. 0. 0. 0.]
 [ 0. 0. 198. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [ 0. 0. 1. 231. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [ 0. 0. 0. 0. 230. 2. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0.]
 [ 0. 1. 0. 1. 0. 188. 0. 0. 0. 2. 0. 4. 2. 1. 0. 0.]
 [ 0. 0. 0. 0. 2. 0. 283. 2. 0. 1. 2. 0. 2. 0. 8. 0.]
 [ 0. 0. 0. 0. 0. 0. 1. 197. 0. 1. 0. 0. 0. 0. 0. 0.]
 [ 0. 1. 0. 0. 0. 2. 4. 1. 175. 5. 2. 2. 0. 1. 5. 1.]
 [ 0. 3. 0. 1. 0. 5. 0. 0. 3. 282. 0. 0. 0. 0. 1. 0.]
 [ 0. 0. 0. 0. 0. 0. 2. 0. 0. 0. 194. 0. 2. 0. 1. 0.]
 [ 1. 2. 0. 0. 0. 2. 2. 0. 0. 1. 0. 191. 0. 0. 0. 0.]
 [ 1. 0. 1. 0. 1. 0. 5. 0. 0. 0. 1. 0. 211. 0. 1. 0.]
 [ 0. 0. 1. 1. 0. 1. 0. 0. 1. 0. 0. 4. 1. 190. 0. 0.]
 [ 0. 0. 0. 0. 2. 3. 3. 1. 7. 2. 8. 0. 3. 0. 271. 0.]
 [ 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 316.]]
[[0.95238095 0. 0. 0. 0. 0. 0.00952381 0.0047619 0. 0. 0. 0. 0. 0.0047619 0.02857143 0. ]
 [0. 0.94974874 0. 0. 0. 0.01005025 0. 0.00502513 0. 0.00502513 0. 0.03015075 0. 0. 0. 0. ]
 [0. 0. 0.99497487 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.00502513 0. ]
 [0. 0. 0.00431034 0.99568966 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]
 [0. 0. 0. 0. 0.9787234 0.00851064 0.00425532 0.00425532 0. 0. 0. 0. 0.00425532 0. 0. 0. ]
 [0. 0.00502513 0. 0.00502513 0. 0.94472362 0. 0. 0. 0.01005025 0. 0.0201005 0.01005025 0.00502513 0. 0. ]
 [0. 0. 0. 0. 0.00666667 0. 0.94333333 0.00666667 0. 0.00333333 0.00666667 0. 0.00666667 0. 0.02666667 0. ]
 [0. 0. 0. 0. 0. 0. 0.00502513 0.98994975 0. 0.00502513 0. 0. 0. 0. 0. 0. ]
 [0. 0.00502513 0. 0. 0. 0.01005025 0.0201005 0.00502513 0.87939698 0.02512563 0.01005025 0.01005025 0. 0.00502513 0.02512563 0.00502513]
 [0. 0.01016949 0. 0.00338983 0. 0.01694915 0. 0. 0.01016949 0.9559322 0. 0. 0. 0. 0.00338983 0. ]
 [0. 0. 0. 0. 0. 0. 0.01005025 0. 0. 0. 0.97487437 0. 0.01005025 0. 0.00502513 0. ]
 [0.00502513 0.01005025 0. 0. 0. 0.01005025 0.01005025 0. 0. 0.00502513 0. 0.95979899 0. 0. 0. 0. ]
 [0.00452489 0. 0.00452489 0. 0.00452489 0. 0.02262443 0. 0. 0. 0.00452489 0. 0.95475113 0. 0.00452489 0. ]
 [0. 0. 0.00502513 0.00502513 0. 0.00502513 0. 0. 0.00502513 0. 0. 0.0201005 0.00502513 0.95477387 0. 0. ]
 [0. 0. 0. 0. 0.00666667 0.01 0.01 0.00333333 0.02333333 0.00666667 0.02666667 0. 0.01 0. 0.90333333 0. ]
 [0. 0. 0. 0. 0. 0. 0.00314465 0. 0.00314465 0. 0. 0. 0. 0. 0. 0.99371069]]