水稻病害分类数据集,包括4个类别共5932张图像:白叶枯病、稻瘟病、东格鲁病、褐斑病。
好的,以下是详细的文档格式,包含了如何使用YOLOv8进行水稻病害分类的数据集训练。由于这是一个分类任务而不是目标检测任务,我们将使用YOLOv8的分类功能来进行训练。
使用 YOLOv8 训练水稻病害分类
数据集信息
- 类别: 4类 (白叶枯病, 稻瘟病, 东格鲁病, 褐斑病)
- 图片数量: 共5932张
- 分辨率: 假设为统一的分辨率(例如 224x224)
- 标签格式:
- 分类格式 (txt 或 CSV)
步骤概述
- 数据集准备
- 创建数据集配置文件 (
data.yaml
) - 分割数据集
- 训练模型
- 评估模型
- 可视化训练结果
- 清理临时文件
- 推理和显示结果
详细步骤
1. 数据集准备
确保你的数据集已经按照上述格式准备好,并且包含 images
和 labels
目录。
rice_disease_classification/
├── datasets/
│ └── rice_disease_dataset/
│ ├── images/
│ │ ├── image1.jpg
│ │ ├── image2.jpg
│ │ └── ...
│ └── labels/
│ ├── image1.txt
│ ├── image2.txt
│ └── ...
└── main.py
2. 创建数据集配置文件 (data.yaml
)
创建一个 data.yaml
文件来配置数据集路径和类别信息。
train: ./datasets/rice_disease_dataset/train/images
val: ./datasets/rice_disease_dataset/val/images
nc: 4 # 类别数量
names: ['white_leaf_blight', 'blast', 'tungro', 'brown_spot']
3. 分割数据集
将数据集分割成训练集和验证集。
import os
import random
from pathlib import Path
import shutil
def split_dataset(data_dir, train_ratio=0.8):
images = list(Path(data_dir).glob('*.jpg'))
random.shuffle(images)
num_train = int(len(images) * train_ratio)
train_images = images[:num_train]
val_images = images[num_train:]
train_dir = Path(data_dir).parent / 'train'
val_dir = Path(data_dir).parent / 'val'
train_img_dir = train_dir / 'images'
train_label_dir = train_dir / 'labels'
val_img_dir = val_dir / 'images'
val_label_dir = val_dir / 'labels'
train_img_dir.mkdir(parents=True, exist_ok=True)
train_label_dir.mkdir(parents=True, exist_ok=True)
val_img_dir.mkdir(parents=True, exist_ok=True)
val_label_dir.mkdir(parents=True, exist_ok=True)
for img in train_images:
label_path = img.with_suffix('.txt')
shutil.copy(img, train_img_dir / img.name)
shutil.copy(label_path, train_label_dir / label_path.name)
for img in val_images:
label_path = img.with_suffix('.txt')
shutil.copy(img, val_img_dir / img.name)
shutil.copy(label_path, val_label_dir / label_path.name)
# 使用示例
split_dataset('./datasets/rice_disease_dataset/images')
4. 训练模型
使用YOLOv8进行分类训练。
import torch
from ultralytics import YOLO
# 设置随机种子以保证可重复性
torch.manual_seed(42)
# 定义数据集路径
dataset_config = 'data.yaml'
# 加载预训练的YOLOv8n-c模型(用于分类)
model = YOLO('yolov8n-cls.pt')
# 训练模型
results = model.train(
data=dataset_config,
epochs=100,
imgsz=224,
batch=16,
name='rice_disease_classification',
project='runs/classify'
)
# 评估模型
metrics = model.val()
# 保存最佳模型权重
best_model_weights = 'runs/classify/rice_disease_classification/weights/best.pt'
print(f"Best model weights saved to {best_model_weights}")
5. 可视化训练结果
可视化训练结果。
from ultralytics import YOLO
# 加载训练好的模型
model = YOLO('runs/classify/rice_disease_classification/weights/best.pt')
# 可视化训练结果
model.plot_results(save=True, save_dir='runs/classify/rice_disease_classification')
6. 清理临时文件
清理不必要的临时文件。
import shutil
def clean_temp_files(project_dir):
temp_dirs = [
f'{project_dir}/wandb',
f'{project_dir}/cache'
]
for dir_path in temp_dirs:
if os.path.exists(dir_path):
shutil.rmtree(dir_path)
print(f"Removed directory: {dir_path}")
# 使用示例
clean_temp_files('runs/classify/rice_disease_classification')
7. 推理和显示结果
推理和显示结果。
from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
# 主函数
def main():
# 加载模型
model = YOLO('runs/classify/rice_disease_classification/weights/best.pt')
# 测试图片路径
test_image_path = './datasets/rice_disease_dataset/images/test_image.jpg'
# 进行预测
results = model.predict(test_image_path, conf=0.5)[0]
# 获取预测结果
predicted_class_id = int(results.probs.argmax())
predicted_class_name = model.names[predicted_class_id]
confidence = float(results.probs.max())
# 打印预测结果
print(f"Predicted Class: {predicted_class_name}, Confidence: {confidence:.2f}")
# 显示图片
image = cv2.imread(test_image_path)
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 在图片上添加预测标签
font = cv2.FONT_HERSHEY_SIMPLEX
text = f"{predicted_class_name} ({confidence:.2f})"
position = (10, 30)
font_scale = 1
color = (0, 255, 0)
thickness = 2
cv2.putText(image_rgb, text, position, font, font_scale, color, thickness)
plt.imshow(image_rgb)
plt.axis('off')
plt.show()
if __name__ == "__main__":
main()
运行脚本
在终端中运行以下命令来执行整个流程:
python main.py
总结
以上文档包含了从数据集准备、模型训练、评估、可视化训练结果、清理临时文件到推理和显示结果的所有步骤。希望这些详细的信息和代码能够帮助你顺利实施和优化你的水稻病害分类系统。