1.研究背景与意义
项目参考AAAI Association for the Advancement of Artificial Intelligence
研究背景与意义:
随着计算机视觉和机器学习的快速发展,图像分类成为了一个热门的研究领域。在许多实际应用中,如农业、食品安全和市场调研等领域,对水果图像进行准确分类和品种分级具有重要意义。然而,由于水果的形状、颜色和纹理等特征的多样性,以及光照条件和拍摄角度的变化,水果图像分类面临着许多挑战。
目前,基于深度学习的图像分类方法已经取得了显著的成果。其中,卷积神经网络(CNN)是一种非常有效的方法,可以自动学习图像的特征表示。然而,传统的CNN模型在处理水果图像分类时仍然存在一些问题。首先,传统的CNN模型对于水果图像中的颜色信息没有充分利用,导致分类准确率较低。其次,传统的CNN模型对于水果图像中的品种分级任务并不擅长,无法提供细粒度的分类结果。
因此,本研究旨在基于改进的SE-VGG16-BN模型,实现对131种水果图像的准确分类和品种分级。具体来说,本研究将从以下几个方面进行改进和优化:
首先,本研究将引入注意力机制,以增强模型对水果图像中的颜色信息的感知能力。通过学习图像中不同区域的重要性权重,模型可以更好地捕捉到水果图像中的颜色特征,从而提高分类准确率。
其次,本研究将引入细粒度分类的方法,以实现对水果图像的品种分级。通过在模型中增加额外的分类层,可以将水果图像分为更多的细粒度类别,从而提供更具体和详细的分类结果。
最后,本研究将对数据集进行充分的预处理和增强,以提高模型的鲁棒性和泛化能力。通过对数据集进行旋转、缩放和平移等操作,可以增加模型对不同光照条件和拍摄角度的适应能力,从而提高分类的准确性和稳定性。
本研究的意义在于提供了一种基于改进的SE-VGG16-BN模型的水果图像分类系统,可以在农业、食品安全和市场调研等领域中得到广泛应用。通过准确分类和品种分级,可以帮助农民和市场调研人员更好地了解水果的品质和市场需求,从而提高农产品的质量和市场竞争力。此外,本研究还可以为其他图像分类任务提供借鉴和参考,推动深度学习在计算机视觉领域的发展。
2.图片演示
3.视频演示
基于改进SE-VGG16-BN的131种水果蔬菜图像分类系统(颜色、品种分级)_哔哩哔哩_bilibili
4.数据集和训练参数设定
AAAI提供的蔬菜水果数据集包含了131个分类,包含了常见的所有的蔬菜和水果类型,并且根据颜色和类型进行了分级划分。
我们需要将数据集整理为以下结构:
-----data
|-----train
| |-----class1
| |-----class2
| |-----...
|
|-----val
| |-----class1
| |-----class2
| |-----...
(1)为提高训练的效果,加快网络模型的收敛,对两个数据集的花卉图片按照保持长宽比的方式归一化,归一化后的尺寸为224×224×3.
(2)将数据增强后的每类花卉图片数的70%划分为训练集,剩余30%作为测试集.
(3)训练时保留VGG16经 ImageNet 预训练产生的用于特征提取的参数,SE单元模块中用于放缩参数r设置为文献[8]的作者所推荐的16,其余参数均使用正态分布随机值进行初始化.
(4)采用随机梯度下降法来优化模型, batchsize设置为32, epoch设为3000,学习率设为0.001,动量因子设为0.9,权重衰减设为0.000 5.
(5)为了防止过拟合,SE-VGG16 网络模型第6段的两个全连接层的dropout 设置为0.5.
(6)多损失函数融合公式中入参数的值设置为0.5.
5.核心代码讲解
5.1 ConfusionMatrix.py
class ConfusionMatrix(object):
"""
注意,如果显示的图像不全,是matplotlib版本问题
本例程使用matplotlib-3.2.1(windows and ubuntu)绘制正常
需要额外安装prettytable库
"""
def __init__(self, num_classes: int, labels: list):
self.matrix = np.zeros((num_classes, num_classes))
self.num_classes = num_classes
self.labels = labels
def update(self, preds, labels):
for p, t in zip(preds, labels):
self.matrix[p, t] += 1
def summary(self):
# calculate accuracy
sum_TP = 0
for i in range(self.num_classes):
sum_TP += self.matrix[i, i]
acc = sum_TP / np.sum(self.matrix)
print("the model accuracy is ", acc)
# precision, recall, specificity
table = PrettyTable()
table.field_names = ["", "Precision", "Recall", "Specificity"]
for i in range(self.num_classes):
TP = self.matrix[i, i]
FP = np.sum(self.matrix[i, :]) - TP
FN = np.sum(self.matrix[:, i]) - TP
TN = np.sum(self.matrix) - TP - FP - FN
Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.
Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.
Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.
table.add_row([self.labels[i], Precision, Recall, Specificity])
print(table)
def plot(self):
matrix = self.matrix
print(matrix)
plt.figure(figsize=(40, 40), dpi=100) # 设置画布的大小和dpi,为了使图片更加清晰
plt.imshow(matrix, cmap=plt.cm.Blues)
# 设置x轴坐标label
plt.xticks(range(self.num_classes), self.labels, rotation=45)
# 设置y轴坐标label
plt.yticks(range(self.num_classes), self.labels)
# 显示colorbar
plt.colorbar()
plt.xlabel('True Labels')
plt.ylabel('Predicted Labels')
plt.title('Confusion matrix')
# 在图中标注数量/概率信息
thresh = matrix.max() / 2
for x in range(self.num_classes):
for y in range(self.num_classes):
# 注意这里的matrix[y, x]不是matrix[x, y]
info = int(matrix[y, x])
plt.text(x, y, info,
verticalalignment='center',
horizontalalignment='center',
color="white" if info > thresh else "black")
plt.tight_layout()
plt.savefig('./confusion_matrix.png', format='png') # 保存图像为png格式
plt.show()
该程序文件名为ConfusionMatrix.py,主要功能是计算和绘制混淆矩阵。
程序首先导入了所需的库和模块,包括os、json、torch、transforms、datasets、numpy、tqdm、matplotlib和PrettyTable。
然后定义了一个名为ConfusionMatrix的类,该类有以下几个方法:
__init__(self, num_classes: int, labels: list)
:初始化方法,接收分类数和标签列表作为参数,创建一个大小为(num_classes, num_classes)的零矩阵,并保存分类数和标签列表。update(self, preds, labels)
:更新混淆矩阵的方法,接收预测结果和真实标签作为参数,根据预测结果和真实标签更新混淆矩阵。summary(self)
:计算并打印模型的准确率、精确度、召回率和特异度。plot(self)
:绘制混淆矩阵图像,并保存为png格式。
接下来是主程序部分,首先判断是否有可用的GPU,然后定义了数据的预处理方法和数据集路径。
然后创建了一个验证数据集的DataLoader,并加载了预训练的vgg模型权重。
接着读取了类别标签的json文件,并保存了标签列表。
然后创建了一个ConfusionMatrix对象,并将模型设置为评估模式。
在没有梯度的情况下,遍历验证数据集,对每个验证数据进行模型推理,并更新混淆矩阵。
最后调用ConfusionMatrix对象的plot方法绘制混淆矩阵图像,并调用summary方法打印模型的准确率、精确度、召回率和特异度。
5.2 fit.py
# SE模块
class SELayer(nn.Module):
def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1) # 全局平均池化
# 两个全连接层,分别进行降维和升维
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
d