262 种不同水果的 225,640 张图像的数据集
包含绝大多数流行和已知水果的数据集
包括以下水果类型/标签/分支:
阿比乌, 巴西莓, 针叶樱桃, 西非荔枝, 鳄鱼苹果, 琥珀, 苹果, 杏, 阿拉扎, 鳄梨, 贝尔, 香蕉, 巴巴丁, 小檗, 杨梅, 海滩李子, 熊果, 甜椒, 槟榔, 比格奈, bilimbi, 苦瓜, 黑浆果, 黑樱桃, 黑醋栗, 黑桑葚, 黑萨波特, 蓝莓, bolwarra, 瓶葫芦, 巴西坚果, 面包果, 佛手, 水牛莓, burdekin 李子, 缅甸葡萄, caimito, 卡姆果, 卡尼斯尔, 哈密瓜, 开普醋栗, 杨桃, 卡顿, 腰果, 雪松湾樱桃, cempedak, 锡兰醋栗, 切, 切内特, 樱桃, 樱桃, 奇科, 野樱桃, 克莱门汀, 云莓, 无花果, 可可豆, 椰子, 咖啡, 普通沙棘, 玉米仁, 山茱萸樱桃, 蟹苹果, 蔓越莓, 岩高莓, 古布阿苏, 奶油苹果, 达姆森, 日期, 沙漠无花果, 沙漠石灰, 露莓, 火龙果, 榴莲, 茄子, 接骨木果, 象苹果, 余甘, entawak, etrog, 费费约亚, 纤维缎, 无花果, 手指石灰, 加利亚甜瓜, 甘达利亚, 基因, 枸杞, 醋栗, 古米, 葡萄, 葡萄柚, 格雷参与, 石榴, 瓜纳巴纳, 瓜拉纳, 番石榴, 番石榴, 朴树, 硬猕猴桃, 山楂, 猪李, 蜜莓, 金银花, 角瓜, 伊拉瓦拉李子, 印度杏仁, 印度草莓, 伊塔棕榈, jaboticaba, 菠萝蜜, 墨西哥胡椒, 牙买加樱桃, jambul, 日本葡萄干, 茉莉花, jatoba, jocote, jostaberry, 枣, 杜松子, 泰国青柠, 卡希卡蒂亚、卡卡杜李、吉宝、猕猴桃、金桔、昆东、库杰拉、lablab、langsat、lapsi、柠檬、柠檬白杨、白花、lillipilli、青柠、越橘、罗甘莓、龙眼、枇杷、蛋黄果、卢洛、荔枝、马波罗、澳洲坚果、马来苹果、马梅苹果、柑橘、芒果、山竹、马尼拉罗望子、马兰、马林、五月、五月、枸杞、梅林乔、甜瓜梨、米迪姆、奇迹果、假草莓、罗汉果、monstera deliciosa、巴杇达、山木瓜、山番荔枝、蒙杜、 甜瓜, 桃金娘, 南斯, 保姆莓, 纳兰吉拉, 本地樱桃, 本地醋栗, 油桃, 印楝, 农古, 肉豆蔻, 油棕, 旧世界梧桐树, 橄榄, 橙子, 俄勒冈葡萄, 奥塔黑苹果, 木瓜, 百香果, 木瓜, 豌豆, 花生, 梨, pequi, 柿子, 鸽子李子, 猪脸, 霹雳果, 菠萝, 菠萝, 菠萝, pitomba, 李子, 罗汉松, 石榴, 柚子, prikly pear, pulasan, 南瓜, pupunha, 紫苹果浆果, quandong, 木瓜, 红毛丹, 朗布尔、覆盆子、红桑、红醋栗、肋骨、脊葫芦、芮木、玫瑰果、玫瑰桃、玫瑰叶荆棘、仙人掌、萨拉克、盐沼、鲑鱼莓、砂纸无花果、桑托尔、人心果、萨斯卡通、沙棘、海葡萄、雪莓、松科亚、草莓、草莓番石榴、糖苹果、苏里南樱桃、梧桐无花果、罗望子、橘子、丹戎、红豆杉、泰伯里、德克萨斯柿子、顶针莓、番茄、toyon、ugli 水果、香草、天鹅绒罗望子、西瓜、 蜡瓜、白杨、白醋栗、白桑葚、白沙波特、酒莓、旺吉、雅利梨、黄梅、柚子、锯齿形藤蔓、西葫芦
数据集属性
图像总数:225,640。
类数: 262 种水果。
每个标签的图像数量:平均值:861,中位数:1007,标准差:276。(初始目标是每个标签 1,000 个)
图像宽度:平均:213,中位数:209,标准差:19。
图像高度:平均:262,中位数:255,标准差:30。
初始 1000 个目标中缺失的图像:平均值:580,中位数:567,标准差:258。
格式:目录名称表示一个标签,在每个目录中,该标签下的所有图像数据
同一水果的不同品种通常存储在同一个目录中(例如:青苹果、黄苹果和红苹果)。
数据集中的水果图像可以包含水果在其生命周期的所有阶段,也可以包含水果的切片。
图像包含至少 50% 的水果信息(根据手动筛选选择范例)。
图像的背景可以是任何东西(由于数据的性质):单色背景、人手、水果的自然栖息地、树叶等。
没有重复的图像,但有一些图像(同一标签的)具有高度的相似性。
图像可以包含小水印。
一些具有 50-100 张可用图像的水果仍保留在数据集中,但可以丢弃以获得更好的平衡和减少种类。这也是上面提供的缺失图像统计数据差异很大的一个重要原因。
调整大小的数据集
在“resize.py”脚本的帮助下,原始数据集已在 13x16、26x32、52x64、104x128、208x256 维度上进行了标准化。
这些特定维度来自上面提供的平均维度统计信息。目标是尽可能少地丢失水果信息。
有用的脚本部分中提供了用于在脚本中调整大小的算法。
补充说明:
调整大小的过程可能很微妙,也需要很长时间。
在多个实验中使用相同的调整大小的数据集将在比较结果时提供更好的见解。
如果有人想使用原始数据集,仍然可以完成。
如果有人想要不同的维度但使用相同的算法,他们可以使用 “resize.py” 脚本来获取它们。
更多解释可以在剧本或论文中找到。
有用的脚本 - 简短描述
resize.py - 使用抗锯齿功能调整给定尺寸的图像大小以获得更好的质量。
renumber.py - 对目录中的文件重新编号。
label_dictionary.py - 创建用于标签编号映射的字典。
statistics.py - 生成有关数据集的统计信息。
efficient_storage.py - 将图像转换为大型 numpy 数组,并将其存储为“.npz”格式。有助于缩短 IO 加载时间。
test_model.py - 针对一组“.npz”数据测试模型,并生成前 1、5 和 10 预测的每个标签准确性的统计数据。
load_and_predict.py - 带有 GUI 的应用程序,允许选择模型和图像并执行预测和统计。
补充说明:
某些脚本仅在使用 HDD 进行存储时才值得运行。
模块和库的 Python 版本为 3.7。
用于 Tensorflow 的 Direct-ml (AMD GPU) 库。否则,将“tensorflow.compat.v1”调用替换为标准 API 调用。
那么呢?
构建一个用于水果分类的深度学习模型。使用ResNet50进行图像分类,并利用TensorFlow和Keras来实现这一目标。
以下是详细的步骤:
-
环境准备:
- 安装必要的库。
- 下载并组织数据集。
-
数据预处理:
- 将数据集分为训练集、验证集和测试集。
- 格式化标签文件以便于模型使用。
-
模型定义与训练:
- 使用ResNet50进行图像分类。
-
评估与可视化:
- 评估模型性能。
- 可视化结果。
环境准备
首先,我们需要安装必要的库。您可以使用以下命令来设置环境:
pip install tensorflow keras numpy pandas matplotlib scikit-learn opencv-python-headless lxml
数据预处理
假设您的数据集已经下载并存储在datasets/fruits_262_classes
目录中,其中包含多个子目录,每个子目录代表一个水果类别,并且每个子目录中包含相应的图像文件。
,我们将编写一个脚本来检查数据集的有效性,并将数据集分为训练集、验证集和测试集。
[<title="Data Preparation Script for Fruits Classification">]
import os
import shutil
import random
from pathlib import Path
# Define paths
base_path = Path('datasets/fruits_262_classes')
train_dir = base_path / 'train'
val_dir = base_path / 'val'
test_dir = base_path / 'test'
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)
# Function to split dataset into train, val, test sets
def split_dataset(base_path, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
for class_name in os.listdir(base_path):
class_path = base_path / class_name
if not class_path.is_dir():
continue
images = list(class_path.glob('*'))
random.shuffle(images)
num_images = len(images)
num_train = int(num_images * train_ratio)
num_val = int(num_images * val_ratio)
num_test = num_images - num_train - num_val
train_images = images[:num_train]
val_images = images[num_train:num_train+num_val]
test_images = images[num_train+num_val:]
# Create directories for each class in train, val, test
train_class_dir = train_dir / class_name
val_class_dir = val_dir / class_name
test_class_dir = test_dir / class_name
os.makedirs(train_class_dir, exist_ok=True)
os.makedirs(val_class_dir, exist_ok=True)
os.makedirs(test_class_dir, exist_ok=True)
# Move images to respective directories
for img in train_images:
shutil.move(str(img), str(train_class_dir))
for img in val_images:
shutil.move(str(img), str(val_class_dir))
for img in test_images:
shutil.move(str(img), str(test_class_dir))
split_dataset(base_path)
print("Dataset split into train, val, and test sets.")
模型定义与训练
我们将使用ResNet50进行图像分类。以下是训练脚本 train_classification.py
:
[<title="Training Script for Fruits Classification using ResNet50">]
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
# Paths
train_dir = 'datasets/fruits_262_classes/train'
validation_dir = 'datasets/fruits_262_classes/val'
# Data generators
datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
train_generator = datagen.flow_from_directory(
train_dir,
target_size=(256, 256),
batch_size=32,
class_mode='categorical'
)
validation_generator = datagen.flow_from_directory(
validation_dir,
target_size=(256, 256),
batch_size=32,
class_mode='categorical'
)
# Load ResNet50 model without top layer
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(256, 256, 3))
# Add custom layers on top
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(train_generator.num_classes, activation='softmax')(x)
# Combine with base model
model = Model(inputs=base_model.input, outputs=predictions)
# Freeze convolutional base
for layer in base_model.layers:
layer.trainable = False
# Compile the model
model.compile(optimizer=Adam(lr=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
# Callbacks
checkpoint = ModelCheckpoint('best_fruit_model.h5', monitor='val_loss', save_best_only=True, mode='min')
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
# Train the model
history = model.fit(
train_generator,
steps_per_epoch=train_generator.samples // train_generator.batch_size,
validation_data=validation_generator,
validation_steps=validation_generator.samples // validation_generator.batch_size,
epochs=50,
callbacks=[checkpoint, early_stopping],
verbose=1
)
# Save training history
import numpy as np
np.save('fruit_training_history.npy', history.history)
评估与可视化
编写评估脚本 evaluate_classification.py
来计算准确率、混淆矩阵和其他指标,并绘制相应的图表。
[<title="Evaluation Script for Fruits Classification">]
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Paths
test_dir = 'datasets/fruits_262_classes/test'
# Data generator
datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
test_generator = datagen.flow_from_directory(
test_dir,
target_size=(256, 256),
batch_size=32,
class_mode='categorical',
shuffle=False
)
# Load the best model
model = tf.keras.models.load_model('best_fruit_model.h5')
# Predictions
y_pred = model.predict(test_generator)
y_pred_classes = np.argmax(y_pred, axis=1)
# True labels
y_true = test_generator.classes
# Class names
class_names = test_generator.class_indices.keys()
# Classification report
class_report = classification_report(y_true, y_pred_classes, target_names=class_names)
print(class_report)
# Confusion matrix
conf_mat = confusion_matrix(y_true, y_pred_classes)
plt.figure(figsize=(20, 16))
sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix')
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.savefig('fruit_confusion_matrix.png')
plt.show()
# Training history
history = np.load('fruit_training_history.npy', allow_pickle=True).item()
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history['loss'], label='Train Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Loss')
plt.subplot(1, 2, 2)
plt.plot(history['accuracy'], label='Train Accuracy')
plt.plot(history['val_accuracy'], label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')
plt.tight_layout()
plt.savefig('fruit_training_history.png')
plt.show()
用户界面
我们将使用 PyQt5 创建一个简单的 GUI 来加载和运行模型进行实时预测。以下是用户界面脚本 ui.py
:
[<title="PyQt5 Main Window for Fruits Classification">]
import sys
import cv2
import numpy as np
from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QPushButton, QVBoxLayout, QWidget, QFileDialog
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtCore import Qt, QTimer
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import preprocess_input
from PIL import ImageDraw, ImageFont
# Load model
classification_model = tf.keras.models.load_model('best_fruit_model.h5')
# Load class names
with open('datasets/fruits_262_classes/train/class_names.txt', 'r') as f:
class_names = [line.strip() for line in f.readlines()]
class MainWindow(QMainWindow):
def __init__(self):
super().__init__()
self.setWindowTitle("Fruits Classification System")
self.setGeometry(100, 100, 800, 600)
self.initUI()
def initUI(self):
self.central_widget = QWidget()
self.setCentralWidget(self.central_widget)
self.layout = QVBoxLayout()
self.image_label = QLabel(self)
self.image_label.setAlignment(Qt.AlignCenter)
self.layout.addWidget(self.image_label)
self.load_image_button = QPushButton("Load Image", self)
self.load_image_button.clicked.connect(self.load_image)
self.layout.addWidget(self.load_image_button)
self.start_prediction_button = QPushButton("Start Prediction", self)
self.start_prediction_button.clicked.connect(self.start_prediction)
self.layout.addWidget(self.start_prediction_button)
self.stop_prediction_button = QPushButton("Stop Prediction", self)
self.stop_prediction_button.clicked.connect(self.stop_prediction)
self.layout.addWidget(self.stop_prediction_button)
self.central_widget.setLayout(self.layout)
self.image_path = None
self.timer = QTimer()
self.timer.timeout.connect(self.update_frame)
def load_image(self):
options = QFileDialog.Options()
file_name, _ = QFileDialog.getOpenFileName(self, "QFileDialog.getOpenFileName()", "", "Images (*.png *.jpg *.jpeg);;All Files (*)", options=options)
if file_name:
self.image_path = file_name
self.display_image(file_name)
def display_image(self, path):
pixmap = QPixmap(path)
scaled_pixmap = pixmap.scaled(self.image_label.width(), self.image_label.height(), Qt.KeepAspectRatio)
self.image_label.setPixmap(scaled_pixmap)
def start_prediction(self):
if self.image_path is not None and not self.timer.isActive():
self.timer.start(30) # Update frame every 30 ms
def stop_prediction(self):
if self.timer.isActive():
self.timer.stop()
self.image_label.clear()
def update_frame(self):
original_image = cv2.imread(self.image_path)
image_rgb = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
# Resize image
resized_image = cv2.resize(image_rgb, (256, 256))
preprocessed_image = preprocess_input(np.expand_dims(resized_image, axis=0))
# Prediction
prediction = classification_model.predict(preprocessed_image)
predicted_class_index = np.argmax(prediction)
predicted_class_name = class_names[predicted_class_index]
confidence = prediction[0][predicted_class_index]
# Draw bounding box and text
font = cv2.FONT_HERSHEY_SIMPLEX
cv2.putText(image_rgb, f'{predicted_class_name} ({confidence:.2f})', (10, 30), font, 0.9, (0, 255, 0), 2)
h, w, ch = image_rgb.shape
bytes_per_line = ch * w
qt_image = QImage(image_rgb.data, w, h, bytes_per_line, QImage.Format_RGB888)
pixmap = QPixmap.fromImage(qt_image)
scaled_pixmap = pixmap.scaled(self.image_label.width(), self.image_label.height(), Qt.KeepAspectRatio)
self.image_label.setPixmap(scaled_pixmap)
if __name__ == "__main__":
app = QApplication(sys.argv)
window = MainWindow()
window.show()
sys.exit(app.exec_())
请确保将路径替换为您实际的路径。
使用说明
-
配置路径:
- 确保
datasets/fruits_262_classes
目录结构正确,并且包含train
、val
和test
子目录。 - 确保
best_fruit_model.h5
是训练好的 ResNet50 模型权重路径。
- 确保
-
运行脚本:
- 在终端中运行
data_preparation.py
脚本来检查数据集的有效性并创建训练、验证和测试集。 - 在终端中运行
train_classification.py
脚本来训练图像分类模型。 - 在终端中运行
evaluate_classification.py
来评估图像分类模型性能。 - 在终端中运行
ui.py
来启动 GUI 应用程序。
- 在终端中运行
-
注意事项:
- 确保所有必要的工具箱已安装,特别是 TensorFlow 和 PyQt5。
- 根据需要调整参数,如
epochs
和batch_size
。
示例
假设您的数据文件夹结构如下:
datasets/
└── fruits_262_classes/
├── train/
│ ├── apple/
│ ├── banana/
│ └── ...
├── val/
│ ├── apple/
│ ├── banana/
│ └── ...
└── test/
├── apple/
├── banana/
└── ...
并且每个数据集中包含相应的图像文件。运行 ui.py
后,您可以点击按钮来加载图像并进行水果分类。