基于Fruits-360数据集构建CNN进行水果识别实验

基于Fruits-360数据集的水果识别项目

前段时间导师要求做一个神经网络可视化的项目,要将水果数据集进行训练得到模型,用于TensorSpace可视化。前前后后捣鼓了很久,现在回过头总结一下整个项目过程,写下这篇博客记录遇到的问题,有任何问题欢迎在评论区留言。

  • 1.实验数据集

  • (1)实验用的数据集是最常见的Fruits-360水果数据集,截至写博客为止这个数据集最新版本是2020.05.18.0
  • (2)该数据集有131种水果分类(分的特别细致,比方说苹果(不同品种:深红色雪,金黄色,金红色,史密斯奶奶,粉红色淑女,红色,红色美味)):Apples (different varieties: Crimson Snow, Golden, Golden-Red, Granny Smith, Pink Lady, Red, Red Delicious), Apricot, Avocado, Avocado ripe, Banana (Yellow, Red, Lady Finger), Beetroot Red, Blueberry, Cactus fruit, Cantaloupe (2 varieties), Carambula, Cauliflower, Cherry (different varieties, Rainier), Cherry Wax (Yellow, Red, Black), Chestnut, Clementine, Cocos, Corn (with husk), Cucumber (ripened), Dates, Eggplant, Fig, Ginger Root, Granadilla, Grape (Blue, Pink, White (different varieties)), Grapefruit (Pink, White), Guava, Hazelnut, Huckleberry, Kiwi, Kaki, Kohlrabi, Kumsquats, Lemon (normal, Meyer), Lime, Lychee, Mandarine, Mango (Green, Red), Mangostan, Maracuja, Melon Piel de Sapo, Mulberry, Nectarine (Regular, Flat), Nut (Forest, Pecan), Onion (Red, White), Orange, Papaya, Passion fruit, Peach (different varieties), Pepino, Pear (different varieties, Abate, Forelle, Kaiser, Monster, Red, Stone, Williams), Pepper (Red, Green, Orange, Yellow), Physalis (normal, with Husk), Pineapple (normal, Mini), Pitahaya Red, Plum (different varieties), Pomegranate, Pomelo Sweetie, Potato (Red, Sweet, White), Quince, Rambutan, Raspberry, Redcurrant, Salak, Strawberry (normal, Wedge), Tamarillo, Tangelo, Tomato (different varieties, Maroon, Cherry Red, Yellow, not ripened, Heart), Walnut, Watermelon.
  • (3)数据集属性
    • 图片总数:90483。
    • 训练集大小:67692张图像(每张图像一个水果)
    • 测试集大小:22688张图像(每张图像一个水果)
    • 种类数:131(水果)
    • 图片尺寸:100 x 100像素。
    • 文件名格式:图像索引_100.jpg(例如32_100.jpg)或r_图像索引_100.jpg(例如r_32_100.jpg)或r2_图像索引_100.jpg或r3_图像索引_100.jpg。“ r”代表旋转的水果。“ r2”表示水果绕第三轴旋转。“100”来自图像尺寸(100x100像素)。
    • 同一水果(例如苹果)的不同品种被存储为属于不同类别。
  • (4)数据集下载链接:Fruits-360.
  • 2.导包

实验用到了以下包:

// 导包
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt 
import glob 
import cv2 
import tensorflow as tf
import keras
import os
print(os.listdir("../dataset/Training"))
['Apple Braeburn', 'Apple Crimson Snow', 'Apple Golden 1', 'Apple Golden 2', 'Apple Golden 3', 'Apple Granny Smith', 'Apple Pink Lady', 'Apple Red 1', 'Apple Red 2', 'Apple Red 3', 'Apple Red Delicious', 'Apple Red Yellow 1', 'Apple Red Yellow 2', 'Apricot', 'Avocado', 'Avocado ripe', 'Banana', 'Banana Lady Finger', 'Banana Red', 'Beetroot', 'Blueberry', 'Cactus fruit', 'Cantaloupe 1', 'Cantaloupe 2', 'Carambula', 'Cauliflower', 'Cherry 1', 'Cherry 2', 'Cherry Rainier', 'Cherry Wax Black', 'Cherry Wax Red', 'Cherry Wax Yellow', 'Chestnut', 'Clementine', 'Cocos', 'Corn', 'Corn Husk', 'Cucumber Ripe', 'Cucumber Ripe 2', 'Dates', 'Eggplant', 'Fig', 'Ginger Root', 'Granadilla', 'Grape Blue', 'Grape Pink', 'Grape White', 'Grape White 2', 'Grape White 3', 'Grape White 4', 'Grapefruit Pink', 'Grapefruit White', 'Guava', 'Hazelnut', 'Huckleberry', 'Kaki', 'Kiwi', 'Kohlrabi', 'Kumquats', 'Lemon', 'Lemon Meyer', 'Limes', 'Lychee', 'Mandarine', 'Mango', 'Mango Red', 'Mangostan', 'Maracuja', 'Melon Piel de Sapo', 'Mulberry', 'Nectarine', 'Nectarine Flat', 'Nut Forest', 'Nut Pecan', 'Onion Red', 'Onion Red Peeled', 'Onion White', 'Orange', 'Papaya', 'Passion Fruit', 'Peach', 'Peach 2', 'Peach Flat', 'Pear', 'Pear 2', 'Pear Abate', 'Pear Forelle', 'Pear Kaiser', 'Pear Monster', 'Pear Red', 'Pear Stone', 'Pear Williams', 'Pepino', 'Pepper Green', 'Pepper Orange', 'Pepper Red', 'Pepper Yellow', 'Physalis', 'Physalis with Husk', 'Pineapple', 'Pineapple Mini', 'Pitahaya Red', 'Plum', 'Plum 2', 'Plum 3', 'Pomegranate', 'Pomelo Sweetie', 'Potato Red', 'Potato Red Washed', 'Potato Sweet', 'Potato White', 'Quince', 'Rambutan', 'Raspberry', 'Redcurrant', 'Salak', 'Strawberry', 'Strawberry Wedge', 'Tamarillo', 'Tangelo', 'Tomato 1', 'Tomato 2', 'Tomato 3', 'Tomato 4', 'Tomato Cherry Red', 'Tomato Heart', 'Tomato Maroon', 'Tomato not Ripened', 'Tomato Yellow', 'Walnut', 'Watermelon']
  • 3.数据预处理

  • (1)训练集处理

// 训练集处理
training_fruit_img = []
training_label = []
for dir_path in glob.glob("../dataset/Training/*"):
    img_label = dir_path.split("/")[-1]
    for img_path in glob.glob(os.path.join(dir_path, "*.jpg")):
        img = cv2.imread(img_path)
        img = cv2.resize(img, (64, 64))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        training_fruit_img.append(img)
        training_label.append(img_label)
training_fruit_img = np.array(training_fruit_img)
training_label = np.array(training_label)
print(len(np.unique(training_label)))
131
  • (2)测试集处理

// 测试集处理
test_fruit_img = []
test_label = []
for dir_path in glob.glob("../dataset/Test/*"):
    img_label = dir_path.split("/")[-1]
    for img_path in glob.glob(os.path.join(dir_path, "*.jpg")):
        img = cv2.imread(img_path)
        img = cv2.resize(img, (64, 64))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        test_fruit_img.append(img)
        test_label.append(img_label)
test_fruit_img = np.array(test_fruit_img)
test_label = np.array(test_label)
print(len(np.unique(test_label)))
131
  • (3)测试集(混合水果图像)处理

// 测试集(混合水果图像)处理
test_fruits_img = []
tests_label = []
for img_path in glob.glob(os.path.join("../dataset/test-multiple_fruits", "*.jpg")):
     img_label = img_path.split("/")[-1]
     img = cv2.imread(img_path)
     img = cv2.resize(img, (64, 64))
     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
     test_fruits_img.append(img)
     tests_label.append(img_label)
test_fruits_img = np.array(test_fruits_img)
tests_label = np.array(tests_label)
len(np.unique(tests_label))
103
  • (4)训练集标签和id互相转化

trainging_label_to_id = {v : k for k, v in enumerate(np.unique(training_label))}
training_id_to_label = {v : k for k, v in trainging_label_to_id.items()}
training_label_id = np.array([trainging_label_to_id[i] for i in training_label])
print(training_label_id)
array([  0,   0,   0, ..., 130, 130, 130])
  • (5)测试集标签和id互相转化

  • 这部分同训练集,也是使用enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,即将测试集标签去重后组合为一个索引序列,并进行键值对对换,再转化为NumPy数组。
  • (6)像素值缩放

  • 这里我将训练集图像和测试集图像乘以1/255进行缩放,将像素值缩放至[0, 1]区间。
  • 然后使用matplotlib.pyplot将图像显示出来(实验中我以训练集第10001个图像为例)
    训练集第10001个图像
  • 4.Keras模型构建

  • 我实验中的搭建的模型结构如下:
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 64, 64, 16)        448       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 32, 32, 16)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 32, 32, 32)        4640      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 16, 16, 32)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 16, 16, 32)        9248      
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 8, 8, 32)          0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 8, 8, 64)          18496     
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 4, 4, 64)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 1024)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 256)               262400    
_________________________________________________________________
dense_2 (Dense)              (None, 131)               33667     
=================================================================
Total params: 328,899
Trainable params: 328,899
Non-trainable params: 0
_________________________________________________________________
  • 5.模型编译和训练

  • 我实验中使用的是交叉熵损失函数sparse_categorical_crossentropy,优化器使用的是Adamaxbatch_size设置为128,epochs设置为10,并且使用TensorBoard将训练过程可视化。
Train on 67692 samples
Epoch 1/10
67692/67692 [==============================] - 145s 2ms/sample - loss: 1.6539 - accuracy: 0.5947
Epoch 2/10
67692/67692 [==============================] - 156s 2ms/sample - loss: 0.2184 - accuracy: 0.9370
Epoch 3/10
67692/67692 [==============================] - 179s 3ms/sample - loss: 0.0812 - accuracy: 0.9764
Epoch 4/10
67692/67692 [==============================] - 264s 4ms/sample - loss: 0.0466 - accuracy: 0.9864
Epoch 5/10
67692/67692 [==============================] - 272s 4ms/sample - loss: 0.0257 - accuracy: 0.9932
Epoch 6/10
67692/67692 [==============================] - 257s 4ms/sample - loss: 0.0160 - accuracy: 0.9958
Epoch 7/10
67692/67692 [==============================] - 301s 4ms/sample - loss: 0.0164 - accuracy: 0.9956
Epoch 8/10
67692/67692 [==============================] - 289s 4ms/sample - loss: 0.0094 - accuracy: 0.9976
Epoch 9/10
67692/67692 [==============================] - 252s 4ms/sample - loss: 0.0113 - accuracy: 0.9967
Epoch 10/10
67692/67692 [==============================] - 245s 4ms/sample - loss: 0.0062 - accuracy: 0.9982
  • 6.在测试集上评估模型

22688/22688 [==============================] - 30s 1ms/step


Loss: 0.25497715034207047
Accuracy: 0.9393512010574341
  • 可以看到该模型精确率达到93.94%,效果相对来说还算是不错的。
  • 7.保存模型

  • 将训练完成的模型进行保存,得到一个.h5文件,这个文件也是后续TensorSpace项目所需要用到的核心文件。
model.save("../model/model_demo.h5")
  • 11
    点赞
  • 77
    收藏
    觉得还不错? 一键收藏
  • 34
    评论
1. 数据集介绍 fruits 360是一个开源的水果图像数据集,包含了75种不同的水果,共约8万张图片。每种水果的图片数量不同,最多的是苹果(约7,000张),最少的是柠檬(约200张)。数据集中的图片都是经过调整大小和中心裁剪的,大小为100x100像素。数据集中的每种水果都有多个变体,例如不同成熟度的香蕉、不同颜色的苹果等等。 2. 算法设计 本算法采用卷积神经网络(CNN进行图像分类。CNN是一种特殊的神经网络,可以自动提取图像中的特征,并将其用于分类。CNN的核心是卷积层和池化层,可以有效地减少参数数量,从而避免过拟合现象。此外,本算法还采用了数据增强技术,对训练集进行随机旋转、翻转、缩放等操作,以增加模型的鲁棒性。 3. 算法实现 本算法使用PyTorch框架进行实现。具体实现过程如下: 3.1 数据预处理 将fruits 360数据集下载到本地,并将其分为训练集和测试集。使用PyTorch提供的transforms模块对数据进行预处理,包括调整大小、随机旋转、随机水平翻转、随机竖直翻转、随机裁剪等操作。为了防止过拟合,训练集还进行了随机缩放操作。最终得到了训练集和测试集的数据加载器。 3.2 网络设计 本算法采用了一个简单的卷积神经网络,包括3个卷积层、3个池化层和3个全连接层。卷积层的卷积核大小为3x3,步长为1,补零为1,激活函数为ReLU;池化层的池化核大小为2x2,步长为2;全连接层的输出大小为75,即水果的种类数。具体网络结构如下: Conv2d(3, 32, 3, padding=1) ReLU(inplace=True) MaxPool2d(2, 2) Conv2d(32, 64, 3, padding=1) ReLU(inplace=True) MaxPool2d(2, 2) Conv2d(64, 128, 3, padding=1) ReLU(inplace=True) MaxPool2d(2, 2) Flatten() Linear(128 * 12 * 12, 512) ReLU(inplace=True) Linear(512, 256) ReLU(inplace=True) Linear(256, 75) 3.3 模型训练 采用交叉熵损失函数和随机梯度下降(SGD)优化器进行模型训练。初始学习率为0.01,每20个epoch衰减一次为原来的0.1。训练过程中,每个epoch会计算训练集和测试集的损失和准确率,并将结果保存到日志文件中。 4. 实验结果 经过100个epoch的训练,本算法在测试集上的准确率达到了96.8%。部分预测结果如下图所示: ![image](https://github.com/ShiniuPython/fruit_classification/blob/master/result.png) 可以看到,本算法在大多数情况下都能正确识别水果的种类。但是有些水果的不同变体之间相似度较高,如橙子和柠檬,有时候难以区分。此外,本算法对于水果的形状、颜色等变化较大的情况下也有一定的识别误差。 5. 总结 本算法采用了卷积神经网络进行图像分类,通过数据增强技术提高了模型的鲁棒性。实验结果表明,本算法可以有效地识别大多数水果的种类。但是,对于一些相似度较高的水果和变化较大的水果,还需要进一步改进。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 34
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值