基于改进SE-VGG16-BN的131种水果蔬菜图像分类系统(颜色、品种分级)

本研究介绍了一个基于改进SE-VGG16-BN模型的水果蔬菜图像分类系统,针对131种果蔬进行颜色和品种分级。通过引入注意力机制、细粒度分类和数据增强,提高了模型对颜色特征的感知和分类准确性。系统使用多种损失函数融合,加速了网络收敛,提升了分类效率和准确性,适用于农业、食品安全等领域。
摘要由CSDN通过智能技术生成

1.研究背景与意义

项目参考AAAI Association for the Advancement of Artificial Intelligence

研究背景与意义:

随着计算机视觉和机器学习的快速发展,图像分类成为了一个热门的研究领域。在许多实际应用中,如农业、食品安全和市场调研等领域,对水果图像进行准确分类和品种分级具有重要意义。然而,由于水果的形状、颜色和纹理等特征的多样性,以及光照条件和拍摄角度的变化,水果图像分类面临着许多挑战。

目前,基于深度学习的图像分类方法已经取得了显著的成果。其中,卷积神经网络(CNN)是一种非常有效的方法,可以自动学习图像的特征表示。然而,传统的CNN模型在处理水果图像分类时仍然存在一些问题。首先,传统的CNN模型对于水果图像中的颜色信息没有充分利用,导致分类准确率较低。其次,传统的CNN模型对于水果图像中的品种分级任务并不擅长,无法提供细粒度的分类结果。

因此,本研究旨在基于改进的SE-VGG16-BN模型,实现对131种水果图像的准确分类和品种分级。具体来说,本研究将从以下几个方面进行改进和优化:

首先,本研究将引入注意力机制,以增强模型对水果图像中的颜色信息的感知能力。通过学习图像中不同区域的重要性权重,模型可以更好地捕捉到水果图像中的颜色特征,从而提高分类准确率。

其次,本研究将引入细粒度分类的方法,以实现对水果图像的品种分级。通过在模型中增加额外的分类层,可以将水果图像分为更多的细粒度类别,从而提供更具体和详细的分类结果。

最后,本研究将对数据集进行充分的预处理和增强,以提高模型的鲁棒性和泛化能力。通过对数据集进行旋转、缩放和平移等操作,可以增加模型对不同光照条件和拍摄角度的适应能力,从而提高分类的准确性和稳定性。

本研究的意义在于提供了一种基于改进的SE-VGG16-BN模型的水果图像分类系统,可以在农业、食品安全和市场调研等领域中得到广泛应用。通过准确分类和品种分级,可以帮助农民和市场调研人员更好地了解水果的品质和市场需求,从而提高农产品的质量和市场竞争力。此外,本研究还可以为其他图像分类任务提供借鉴和参考,推动深度学习在计算机视觉领域的发展。

2.图片演示

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.视频演示

基于改进SE-VGG16-BN的131种水果蔬菜图像分类系统(颜色、品种分级)_哔哩哔哩_bilibili

4.数据集和训练参数设定

AAAI提供的蔬菜水果数据集包含了131个分类,包含了常见的所有的蔬菜和水果类型,并且根据颜色和类型进行了分级划分。
在这里插入图片描述
我们需要将数据集整理为以下结构:

-----data
   |-----train
   |   |-----class1
   |   |-----class2
   |   |-----...
   |
   |-----val
   |   |-----class1
   |   |-----class2
   |   |-----...
   

(1)为提高训练的效果,加快网络模型的收敛,对两个数据集的花卉图片按照保持长宽比的方式归一化,归一化后的尺寸为224×224×3.
(2)将数据增强后的每类花卉图片数的70%划分为训练集,剩余30%作为测试集.
(3)训练时保留VGG16经 ImageNet 预训练产生的用于特征提取的参数,SE单元模块中用于放缩参数r设置为文献[8]的作者所推荐的16,其余参数均使用正态分布随机值进行初始化.
(4)采用随机梯度下降法来优化模型, batchsize设置为32, epoch设为3000,学习率设为0.001,动量因子设为0.9,权重衰减设为0.000 5.
(5)为了防止过拟合,SE-VGG16 网络模型第6段的两个全连接层的dropout 设置为0.5.
(6)多损失函数融合公式中入参数的值设置为0.5.

5.核心代码讲解

5.1 ConfusionMatrix.py


class ConfusionMatrix(object):
    """
    注意,如果显示的图像不全,是matplotlib版本问题
    本例程使用matplotlib-3.2.1(windows and ubuntu)绘制正常
    需要额外安装prettytable库
    """
    def __init__(self, num_classes: int, labels: list):
        self.matrix = np.zeros((num_classes, num_classes))
        self.num_classes = num_classes
        self.labels = labels

    def update(self, preds, labels):
        for p, t in zip(preds, labels):
            self.matrix[p, t] += 1

    def summary(self):
        # calculate accuracy
        sum_TP = 0
        for i in range(self.num_classes):
            sum_TP += self.matrix[i, i]
        acc = sum_TP / np.sum(self.matrix)
        print("the model accuracy is ", acc)

        # precision, recall, specificity
        table = PrettyTable()
        table.field_names = ["", "Precision", "Recall", "Specificity"]
        for i in range(self.num_classes):
            TP = self.matrix[i, i]
            FP = np.sum(self.matrix[i, :]) - TP
            FN = np.sum(self.matrix[:, i]) - TP
            TN = np.sum(self.matrix) - TP - FP - FN
            Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.
            Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.
            Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.
            table.add_row([self.labels[i], Precision, Recall, Specificity])
        print(table)

    def plot(self):
        matrix = self.matrix
        print(matrix)
        plt.figure(figsize=(40, 40), dpi=100)  # 设置画布的大小和dpi,为了使图片更加清晰
        plt.imshow(matrix, cmap=plt.cm.Blues)


        # 设置x轴坐标label
        plt.xticks(range(self.num_classes), self.labels, rotation=45)
        # 设置y轴坐标label
        plt.yticks(range(self.num_classes), self.labels)
        # 显示colorbar
        plt.colorbar()
        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix')

        # 在图中标注数量/概率信息
        thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # 注意这里的matrix[y, x]不是matrix[x, y]
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        plt.savefig('./confusion_matrix.png', format='png')  # 保存图像为png格式
        plt.show()



该程序文件名为ConfusionMatrix.py,主要功能是计算和绘制混淆矩阵。

程序首先导入了所需的库和模块,包括os、json、torch、transforms、datasets、numpy、tqdm、matplotlib和PrettyTable。

然后定义了一个名为ConfusionMatrix的类,该类有以下几个方法:

  • __init__(self, num_classes: int, labels: list):初始化方法,接收分类数和标签列表作为参数,创建一个大小为(num_classes, num_classes)的零矩阵,并保存分类数和标签列表。
  • update(self, preds, labels):更新混淆矩阵的方法,接收预测结果和真实标签作为参数,根据预测结果和真实标签更新混淆矩阵。
  • summary(self):计算并打印模型的准确率、精确度、召回率和特异度。
  • plot(self):绘制混淆矩阵图像,并保存为png格式。

接下来是主程序部分,首先判断是否有可用的GPU,然后定义了数据的预处理方法和数据集路径。

然后创建了一个验证数据集的DataLoader,并加载了预训练的vgg模型权重。

接着读取了类别标签的json文件,并保存了标签列表。

然后创建了一个ConfusionMatrix对象,并将模型设置为评估模式。

在没有梯度的情况下,遍历验证数据集,对每个验证数据进行模型推理,并更新混淆矩阵。

最后调用ConfusionMatrix对象的plot方法绘制混淆矩阵图像,并调用summary方法打印模型的准确率、精确度、召回率和特异度。

5.2 fit.py

# SE模块
class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # 全局平均池化
        # 两个全连接层,分别进行降维和升维
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    d
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值