基于迁移学习的水果分类和成熟度识别系统
一、数据集收集
本系统使用的数据集是来自Kaggle的“Fresh and Rotten Classification”,下载方式如下:
import kagglehub
# Download latest version
path = kagglehub.dataset_download("swoyam2609/fresh-and-stale-classification")
print("Path to dataset files:", path)
数据集中包含了两个文件夹,分别是整理好的训练集和测试集,需要注意的是,并没有整理验证集,因此在训练时需要手动或使用代码区分训练集和验证集
二、模型训练
1. 加载数据集
# 划分数据集为训练集和验证集
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) # random_split函数可以将数据集随机划分为两个子数据集,并且不会影响到原始数据和标签的对应关系
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) # 用shuffle打乱数据集有利于在训练时增加模型的泛化能力,防止出现过拟合现象
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False) # 验证集不需要打乱数据,这样每一轮验证时数据的顺序是一样的,有利于评估模型性能
2. 构建模型
迁移学习,顾名思义就是利用已经学到的经验来学习新的知识,将某个领域或任务上学习到的知识或模式应用到不同但相关的领域或问题中。而在计算机视觉中,就是利用别人已经训练过的模型,来对类似的目标进行分类检测任务。在本系统中,导入Resnet50预训练模型进行迁移学习,这样能够极大地缩短训练所需要的时间,减少计算资源消耗。
本次需要识别的对象有14种,而Resnet50模型的全连接层有1000个输出,所以本系统中建立了一个新的全连接层来替换Resnet50中的这一层。
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
num_ftrs = model.fc.in_features # 提取预训练模型的全连接层的输入特征数
model.fc = nn.Linear(num_ftrs, len(dataset.classes)) # 建立新的全连接层,替换预训练模型的这一层
3. 训练模型
Resnet50在大型数据集ImageNet上进行了训练,已经学到了丰富的特征表示。通过冻结这些层、先单独训练全连接层,可以保留这些特征提取能力,而不需要训练整个模型。单独对全连接层训练5个epoch后,解冻其他层进行联合训练,可以逐步调整模型的权重。
需要注意的是,在计算机视觉中,为了统一图像格式,方便模型处理,通常会先进行图像预处理,对图像进行裁剪、增强和归一化等操作。
# 数据预处理和增强
transform = transforms.Compose([
transforms.Resize((150, 150)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
为了防止模型出现过拟合的情况,还可以在预处理时对图像进行随机角度的旋转、随机的水平翻转等,加强训练集的随机性。
通过检查模型训练过程中模型在验证集的准确率,我们可以初步判断模型的分类能力,在得到较为满意的模型后,我们可以使用测试集进一步对模型进行测试。
三、测试模型
在测试之后,可以通过预测准确率、混淆矩阵等来判断模型的性能。
到此为止,算是获得了一个可以用于给水果进行分类的模型,接下来就要开始尝试对自己拍摄的水果进行分类。
四、图像处理
在给果进行分类时,发现几个问题:
- 拍照时尽量不能拍到阴影,阴影部分在边缘检测时很容易造成边缘模糊
- 拍照时要在白光照耀下拍摄,其他颜色的光会造成色差
- 最好在白色背景板下拍摄,减少其他颜色的物品对图像处理造成影响
在保证能够拍摄正常色彩的水果图片后,我们需要做的就是通过图像处理,从图片中准确地找到含有水果的那一部分然后裁剪下来。在图像处理中,这就是主体识别、背景分割。代码能够把含有水果的那部分切割得越清楚,背景元素越少,模型的识别能力就越精确。
在本系统中使用OpenCV库进行图像识别。使用以下代码即可导入:
import cv2
图像处理部分的代码如下所示,这部分内容本人也是第一次接触,感觉写得乱七八糟的,放出来请各位指点。
def process_image(image_path):
image = cv2.imread(image_path) # 使用OpenCV读取图片
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # 转换为灰度图
# cv2.namedWindow("gray Image", 0)
# cv2.resizeWindow("gray Image", 640, 640)
# cv2.imshow("gray Image", gray)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
# # 计算灰度图的直方图
# hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
# # 绘制直方图
# plt.figure()
# plt.title("Gray Image Histogram")
# plt.xlabel("Pixel Value")
# plt.ylabel("Frequency")
# plt.plot(hist)
# plt.xlim([0, 256])
# plt.show()
# 使用形态学操作减弱阴影
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))
top_hat = cv2.morphologyEx(gray, cv2.MORPH_TOPHAT, kernel)
shadow_removed = cv2.add(gray, top_hat)
# 显示阴影减弱后的图像
# cv2.namedWindow("Shadow Removed Image", 0)
# cv2.resizeWindow("Shadow Removed Image", 640, 640)
# cv2.imshow("Shadow Removed Image", shadow_removed)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
# 对灰度图进行对比度增强
alpha = 1.5 # 对比度控制(1.0-3.0)
beta = 10 # 亮度控制(0-100)
enhanced_gray = cv2.convertScaleAbs(shadow_removed, alpha=alpha, beta=beta)
# 显示对比度增强后的图像
# cv2.namedWindow("Enhanced Contrast Image", 0)
# cv2.resizeWindow("Enhanced Contrast Image", 640, 640)
# cv2.imshow("Enhanced Contrast Image", enhanced_gray)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
# 使用闭运算对图像进行处理
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
closed_image = cv2.morphologyEx(enhanced_gray, cv2.MORPH_CLOSE, kernel)
# 显示闭运算后的图像
# cv2.namedWindow("Closed Image", 0)
# cv2.resizeWindow("Closed Image", 640, 640)
# cv2.imshow("Closed Image", closed_image)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
# 对灰度图进行高通滤波
high_passed = high_pass_filter(closed_image, 50) # 使用30作为截止频率
# cv2.namedWindow("High Pass Filtered Image", 0)
# cv2.resizeWindow("High Pass Filtered Image", 640, 640)
# cv2.imshow("High Pass Filtered Image", high_passed)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
# 归一化并转换为8位无符号整数格式
high_passed = cv2.normalize(high_passed, None, 0, 255, cv2.NORM_MINMAX)
high_passed = np.uint8(high_passed)
# 使用Otsu阈值化方法进行二值化
_, otsu_thresh = cv2.threshold(high_passed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
# cv2.namedWindow("Otsu Threshold Image", 0)
# cv2.resizeWindow("Otsu Threshold Image", 640, 640)
# cv2.imshow("Otsu Threshold Image", otsu_thresh)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
edges = cv2.Canny(otsu_thresh, 0, 200, apertureSize=3) # 使用Canny边缘检测器
# cv2.namedWindow("edge Image", 0)
# cv2.resizeWindow("edge Image", 640, 640)
# cv2.imshow("edge Image", edges)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
# 进行膨胀操作
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
dilated_edges = cv2.dilate(edges, kernel, iterations=2)
# cv2.namedWindow("Dilated Edge Image", 0)
# cv2.resizeWindow("Dilated Edge Image", 640, 640)
# cv2.imshow("Dilated Edge Image", dilated_edges)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
contours, _ = cv2.findContours(dilated_edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
max_contour = max(contours, key=cv2.contourArea) # 找到最大轮廓
cv2.drawContours(image, [max_contour], -1, (0, 255, 0), 2) # 绘制最大轮廓
# 显示图像
cv2.namedWindow("Image", 0)
cv2.resizeWindow("Image", 640, 640)
cv2.imshow("Image", image)
cv2.waitKey(0)
cv2.destroyAllWindows()
# 获取最大轮廓的边界框
x, y, w, h = cv2.boundingRect(max_contour)
cropped_image = image[y:y+h, x:x+w] # 裁剪出最大轮廓
# 显示裁剪后的图像
cv2.namedWindow("Cropped Image", 0)
cv2.resizeWindow("Cropped Image", 640, 640)
cv2.imshow("Cropped Image", cropped_image)
cv2.waitKey(0)
cv2.destroyAllWindows()
cropped_image_path = r'pics/cropped_image.jpg'
cv2.imwrite(cropped_image_path, cropped_image) # 保存裁剪后的图像
return cropped_image_path
大体步骤如下:
- 将图片转换为灰度图。这是为了将需要处理的颜色通道减少。
- 为了查看图片中的噪声是如何分布的,计算并绘制灰度图的直方图,发现大多数噪声都是低于100Hz的,所以后续写了一个高通滤波器来对图像进行滤噪。
- 用顶帽操作减小灰度图的阴影影响
- 对灰度图进行对比度增强
- 使用闭运算消除图像中的小黑点,平滑图像边缘
- 高通滤波
- 对图像进行归一化处理并转换为8位无符号整数格式
- 使用otsu阈值进行二值化,得到黑图中的白色轮廓即为水果的轮廓
- 使用canny函数进行轮廓提取
- 为了防止出现部分轮廓未闭合未连续的情况,对提取出的轮廓进行膨胀操作
- 找到最大的闭合轮廓
- 获取最大轮廓的边界框
- 在原图中按照边界框进行裁剪,获得图像中包含水果的那一部分
- 返回裁剪图片的保存路径,作为预测函数的输入参数
五、图像识别并分类
预测函数就是Pytorch的经典写法。
# 定义预测函数
def predict(CropedPicture_Path):
image = Image.open(image_path)
image = transform(image).unsqueeze(0).to(device)
# 进行预测
model.eval()
with torch.no_grad():
outputs = model(image)
_, predicted = torch.max(outputs, 1)
return predicted.item()
至此,一个可以使用的水果分类和成熟度识别系统就完成了。