使用YOLOv8分类模型进行迁移训练对农作物病虫害识别分类

YOLO模型能够进行图像检测、分类、分割、追踪等多种任务,我们将使用YOLO的分类模型进行一个十分基础的苹果叶片病虫害识别。

本文使用的pytorch版本为2.3.0+cu121

YOLOv8模型准备

首先为了能够调用yolo模型,我们首先要安装ultralytics库,直接使用pip安装即可:

pip install ultralytics

然后我还需要到官网下载yolo的预训练模型

yolov8预训练的分类模型主要有n、s、m、l、x五种,五种模型预测准确率不同,相应的内存与运算时间也不同,yolov8n内存最小、运算时间最快,但相应的准确率也较低,这里我选择的是yolov8m-cls预训练模型,大家可以根据自己的设备性能选择适合自己的预训练模型。

数据准备

本文使用的苹果叶片数据集总共有10种类别,每一种类别都代表着一种病害

yolo模型所要求的数据集格式类似于:

- 数据集/
  - train/
    - class1/
    - class2/
  - val/
    - class1/
    - class2/
  - test/
    - class1/
    - class2/

 我们需要把数据集划分为train、valid和test三个文件夹,然后在每个文件夹下,再把图片数据划分到相应的类别(class1、class2)文件夹中。

我在这里使用的苹果叶片数据集已经以8:1:1的比例划分好了训练集、验证集和测试集,下载链接等下会放在文末。

训练模型

下面我们将开始训练模型

首先,导入需要的库

from ultralytics import YOLO
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, Resize, Compose
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
import os

本人在训练的时候,遇到过 Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized这样的报错,据说是因为多个科学计算库不兼容导致的,因此我在文件中添加了下面代码,如果大家没有这样的问题可以不用理会

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

随后加载刚刚下好的预训练模型

model = YOLO('yolov8m-cls.pt')

然后就可以开始训练模型了

model.train(data="apple_dataset", epochs=50, imgsz=640) 

data参数填你的数据集目录,注意:这里的apple_dataset应该是包含train、val、test三个文件夹的。

epochs参数为训练的周期,这里设置为50

imgsz为用来训练模型的图像大小,即模型会将图片重新缩放为imgsz的大小,这里我设置为640。imgsz会在一定程度上影响模型的预测准确率,imgsz越大,相应的准确率也会变高,但相应的训练速度也会变慢,所需内存也会变多。如果内存不足的话,大家可以把imgsz设置为448或256。

随后模型就会开始训练了

模型训练完成后会在当前目录下创建一个runs文件夹,里面的classify中记录了每次分裂模型训练的有关文件。

其中args.yaml记录了这次训练所设置的模型参数;confusion_matrix则是混淆矩阵,可以用来评估模型;events.out.tfevents开头的文件则是tensorboard的日志文件,可以通过加载这个文件来查看训练过程中超参数、模型损失等数据的变化;其他的还包括一些训练过程中每个batch具体的训练图片以及验证过程中的验证结果。

weight文件夹则是模型最后一个epoch时的checkpoint和模型表现最好的checkpoint。

测试模型

模型训练完毕后,我们可以看看模型在测试集上的表现

test_dir = "apple_dataset\\test"
transform = Compose([
    Resize((640, 640)),
    ToTensor(),
])
test_datasets = ImageFolder(test_dir, transform=transform)
test_dataloader = DataLoader(test_datasets, batch_size=1, shuffle=False)

preds = []
labels = []

for batch in test_dataloader:
    img, label = batch
    result = model.predict(img, verbose=False)
    preds.append(result[0].probs.top1)
    labels.append(label.item())

# 计算准确率
accuracy = accuracy_score(preds, labels)
print(f'Accuracy: {accuracy * 100:.2f} %')

 

可以看到模型在最后测试集的准确率上由99%,说明结果还不错

我们还可以随便选择一张图片,看看模型具体的预测结果

predict_img = "apple_dataset\\test\\Black rot\\Black rot (29).JPG"
predict_result = model.predict(predict_img, imgsz=640)

labels = predict_result[0].names
predictt_label = predict_result[0].probs.top5
predict_probility = predict_result[0].probs.top5conf.cpu().numpy()

for i in range(5):
    print(f"是{labels[predictt_label[i]]}的概率为{predict_probility[i] * 100:.2f}%")

 最后的预测结果符合真实结果。

完整代码如下:

from ultralytics import YOLO
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, Resize, Compose
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
import os

//os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

model = YOLO('yolov8m-cls.pt')

model.train(data="apple_dataset", epochs=50, imgsz=640) 

// 预测数据
test_dir = "apple_dataset\\test"
transform = Compose([
    Resize((640, 640)),
    ToTensor(),
])
test_datasets = ImageFolder(test_dir, transform=transform)
test_dataloader = DataLoader(test_datasets, batch_size=1, shuffle=False)

preds = []
labels = []

for batch in test_dataloader:
    img, label = batch
    result = model.predict(img, verbose=False)
    preds.append(result[0].probs.top1)
    labels.append(label.item())

# 计算准确率
accuracy = accuracy_score(preds, labels)
print(f'Accuracy: {accuracy * 100:.2f} %')


// 预测图片
predict_img = "apple_dataset\\test\\Black rot\\Black rot (29).JPG"
predict_result = model.predict(predict_img, imgsz=640)

labels = predict_result[0].names
predictt_label = predict_result[0].probs.top5
predict_probility = predict_result[0].probs.top5conf.cpu().numpy()

for i in range(5):
    print(f"是{labels[predictt_label[i]]}的概率为{predict_probility[i] * 100:.2f}%")

使用的苹果叶片数据集:

链接: https://pan.baidu.com/s/1_FO_005Q6i-nzR-Apx2I2Q?pwd=8mig 提取码: 8mig 

本文只是使用yolo模型完成了一个简单的数据分类任务,大家感兴趣的话可以深入了解一下yolo模型。除了分类任务外,yolo模型在目标检测和目标追踪等任务上也具有不错的表现。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值