如何使用YOLOv8进行水稻病害分类的数据集训练,水稻病害分类数据集 包括4个类别共5932张图像:白叶枯病、稻瘟病、东格鲁病、褐斑病

水稻病害分类数据集,包括4个类别共5932张图像:白叶枯病、稻瘟病、东格鲁病、褐斑病。在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
好的,以下是详细的文档格式,包含了如何使用YOLOv8进行水稻病害分类的数据集训练。由于这是一个分类任务而不是目标检测任务,我们将使用YOLOv8的分类功能来进行训练。

使用 YOLOv8 训练水稻病害分类

数据集信息
  • 类别: 4类 (白叶枯病, 稻瘟病, 东格鲁病, 褐斑病)
  • 图片数量: 共5932张
  • 分辨率: 假设为统一的分辨率(例如 224x224)
  • 标签格式:
    • 分类格式 (txt 或 CSV)
步骤概述
  1. 数据集准备
  2. 创建数据集配置文件 (data.yaml)
  3. 分割数据集
  4. 训练模型
  5. 评估模型
  6. 可视化训练结果
  7. 清理临时文件
  8. 推理和显示结果
详细步骤
1. 数据集准备

确保你的数据集已经按照上述格式准备好,并且包含 imageslabels 目录。

rice_disease_classification/
├── datasets/
│   └── rice_disease_dataset/
│       ├── images/
│       │   ├── image1.jpg
│       │   ├── image2.jpg
│       │   └── ...
│       └── labels/
│           ├── image1.txt
│           ├── image2.txt
│           └── ...
└── main.py
2. 创建数据集配置文件 (data.yaml)

创建一个 data.yaml 文件来配置数据集路径和类别信息。

train: ./datasets/rice_disease_dataset/train/images
val: ./datasets/rice_disease_dataset/val/images

nc: 4  # 类别数量
names: ['white_leaf_blight', 'blast', 'tungro', 'brown_spot']
3. 分割数据集

将数据集分割成训练集和验证集。

import os
import random
from pathlib import Path
import shutil

def split_dataset(data_dir, train_ratio=0.8):
    images = list(Path(data_dir).glob('*.jpg'))
    random.shuffle(images)
    
    num_train = int(len(images) * train_ratio)
    train_images = images[:num_train]
    val_images = images[num_train:]
    
    train_dir = Path(data_dir).parent / 'train'
    val_dir = Path(data_dir).parent / 'val'
    
    train_img_dir = train_dir / 'images'
    train_label_dir = train_dir / 'labels'
    val_img_dir = val_dir / 'images'
    val_label_dir = val_dir / 'labels'
    
    train_img_dir.mkdir(parents=True, exist_ok=True)
    train_label_dir.mkdir(parents=True, exist_ok=True)
    val_img_dir.mkdir(parents=True, exist_ok=True)
    val_label_dir.mkdir(parents=True, exist_ok=True)
    
    for img in train_images:
        label_path = img.with_suffix('.txt')
        shutil.copy(img, train_img_dir / img.name)
        shutil.copy(label_path, train_label_dir / label_path.name)
    
    for img in val_images:
        label_path = img.with_suffix('.txt')
        shutil.copy(img, val_img_dir / img.name)
        shutil.copy(label_path, val_label_dir / label_path.name)

# 使用示例
split_dataset('./datasets/rice_disease_dataset/images')
4. 训练模型

使用YOLOv8进行分类训练。

import torch
from ultralytics import YOLO

# 设置随机种子以保证可重复性
torch.manual_seed(42)

# 定义数据集路径
dataset_config = 'data.yaml'

# 加载预训练的YOLOv8n-c模型(用于分类)
model = YOLO('yolov8n-cls.pt')

# 训练模型
results = model.train(
    data=dataset_config,
    epochs=100,
    imgsz=224,
    batch=16,
    name='rice_disease_classification',
    project='runs/classify'
)

# 评估模型
metrics = model.val()

# 保存最佳模型权重
best_model_weights = 'runs/classify/rice_disease_classification/weights/best.pt'
print(f"Best model weights saved to {best_model_weights}")
5. 可视化训练结果

可视化训练结果。

from ultralytics import YOLO

# 加载训练好的模型
model = YOLO('runs/classify/rice_disease_classification/weights/best.pt')

# 可视化训练结果
model.plot_results(save=True, save_dir='runs/classify/rice_disease_classification')
6. 清理临时文件

清理不必要的临时文件。

import shutil

def clean_temp_files(project_dir):
    temp_dirs = [
        f'{project_dir}/wandb',
        f'{project_dir}/cache'
    ]
    
    for dir_path in temp_dirs:
        if os.path.exists(dir_path):
            shutil.rmtree(dir_path)
            print(f"Removed directory: {dir_path}")

# 使用示例
clean_temp_files('runs/classify/rice_disease_classification')
7. 推理和显示结果

推理和显示结果。

from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# 主函数
def main():
    # 加载模型
    model = YOLO('runs/classify/rice_disease_classification/weights/best.pt')
    
    # 测试图片路径
    test_image_path = './datasets/rice_disease_dataset/images/test_image.jpg'
    
    # 进行预测
    results = model.predict(test_image_path, conf=0.5)[0]
    
    # 获取预测结果
    predicted_class_id = int(results.probs.argmax())
    predicted_class_name = model.names[predicted_class_id]
    confidence = float(results.probs.max())
    
    # 打印预测结果
    print(f"Predicted Class: {predicted_class_name}, Confidence: {confidence:.2f}")
    
    # 显示图片
    image = cv2.imread(test_image_path)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    # 在图片上添加预测标签
    font = cv2.FONT_HERSHEY_SIMPLEX
    text = f"{predicted_class_name} ({confidence:.2f})"
    position = (10, 30)
    font_scale = 1
    color = (0, 255, 0)
    thickness = 2
    
    cv2.putText(image_rgb, text, position, font, font_scale, color, thickness)
    
    plt.imshow(image_rgb)
    plt.axis('off')
    plt.show()

if __name__ == "__main__":
    main()
运行脚本

在终端中运行以下命令来执行整个流程:

python main.py

总结

以上文档包含了从数据集准备、模型训练、评估、可视化训练结果、清理临时文件到推理和显示结果的所有步骤。希望这些详细的信息和代码能够帮助你顺利实施和优化你的水稻病害分类系统。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值