Training a YOLOv8 Object Detector on the DOTA Remote Sensing Dataset: Data Preparation, Model Training, Evaluation, and Result Visualization

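This article walks through the full workflow for the DOTA dataset: data preparation, model training, evaluation, and visualization. DOTA (Dataset for Object Detection in Aerial Images), in its v1.0 release, contains 2,806 aerial images ranging from about 800 × 800 to 4,000 × 4,000 pixels, with 188,282 annotated instances in total. This walkthrough uses the 14 categories listed in the configuration file in step 3 (the official v1.0 release defines 15; bridge is not included here). We use YOLOv8 for the detection task and provide several commonly used scripts to help you process the data and train the model more efficiently.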

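1. Environment Setup

First, make sure the required libraries and tools are installed. You can install them with the following commands (the last line adds imgaug and pycocotools, which the augmentation and evaluation scripts in section 7 import):

pip install ultralytics
pip install torch torchvision
pip install opencv-python
pip install pandas
pip install matplotlib
pip install shapely
pip install imgaug pycocotools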

2. Data Preparation

Assume your dataset directory is organized as follows:

dota/
├── images/
│   ├── train/
│   └── val/
├── labels/
│   ├── train/
│   └── val/

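Each image file and its label file share the same base name, for example P0001.png and P0001.txt. The label files are expected to be in YOLO format: one object per line, written as class x_center y_center width height, with the four box values normalized to the range [0, 1].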
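
The original DOTA annotation files do not use this layout: each object line holds eight polygon coordinates followed by the category name and a difficulty flag. The sketch below shows one way to turn such a line into a YOLO-style horizontal box. It assumes the class order used in the configuration file of step 3; the names dota_line_to_yolo, CLASS_NAMES, and CLASS_TO_ID are illustrative and not part of any library.

```python
# Hypothetical helper: convert one DOTA annotation line to a YOLO label line.
# Assumed DOTA line layout: x1 y1 x2 y2 x3 y3 x4 y4 category difficult
CLASS_NAMES = [
    "small-vehicle", "large-vehicle", "plane", "storage-tank", "ship",
    "harbor", "ground-track-field", "soccer-ball-field", "swimming-pool",
    "helicopter", "roundabout", "tennis-court", "basketball-court",
    "baseball-diamond",
]
CLASS_TO_ID = {name: i for i, name in enumerate(CLASS_NAMES)}


def dota_line_to_yolo(line, img_w, img_h):
    """Return 'cls cx cy w h' (normalized), or None if the line is not a usable object."""
    parts = line.split()
    if len(parts) < 9 or parts[8] not in CLASS_TO_ID:
        return None  # header lines (imagesource:, gsd:) or classes not in the list
    try:
        coords = [float(v) for v in parts[:8]]
    except ValueError:
        return None
    # Clamp the polygon to the image, then take its axis-aligned bounding box.
    xs = [min(max(x, 0.0), img_w) for x in coords[0::2]]
    ys = [min(max(y, 0.0), img_h) for y in coords[1::2]]
    x_min, x_max, y_min, y_max = min(xs), max(xs), min(ys), max(ys)
    if x_max <= x_min or y_max <= y_min:
        return None  # degenerate box after clamping
    cx = (x_min + x_max) / 2 / img_w
    cy = (y_min + y_max) / 2 / img_h
    w = (x_max - x_min) / img_w
    h = (y_max - y_min) / img_h
    return f"{CLASS_TO_ID[parts[8]]} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}"


# A 200 x 200 px plane in a 1000 x 1000 px image
print(dota_line_to_yolo("100 200 300 200 300 400 100 400 plane 0", 1000, 1000))
# -> 2 0.200000 0.300000 0.200000 0.200000
```

Taking the axis-aligned bounding box discards the object's orientation, which matches the horizontal-box detection task trained in this article.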

3. Create the Dataset Configuration File

Create a file named dota.yaml with the following content:

# DOTA Dataset Configuration

# Path to the dataset directory
path: ./dota

# Training and validation image directories
train: images/train
val: images/val

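# Label directories are not configured separately: Ultralytics finds each label
# file by replacing "images" with "labels" in the image path
# (images/train/P0001.png -> labels/train/P0001.txt)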

# Number of classes
nc: 14

# Class names
names:
  0: small-vehicle
  1: large-vehicle
  2: plane
  3: storage-tank
  4: ship
  5: harbor
  6: ground-track-field
  7: soccer-ball-field
  8: swimming-pool
  9: helicopter
  10: roundabout
  11: tennis-court
  12: basketball-court
  13: baseball-diamond
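
Before training, it can help to confirm that the class ids in the label files agree with this configuration. The following is a minimal sketch that counts instances per class id and reports ids outside the range implied by nc. The function name count_label_classes is an assumption made for illustration; the directory path follows the layout from step 2.

```python
import os
from collections import Counter


def count_label_classes(label_dir, num_classes):
    """Count instances per class id and collect ids outside [0, num_classes)."""
    counts = Counter()
    out_of_range = []
    for name in sorted(os.listdir(label_dir)):
        if not name.endswith(".txt"):
            continue
        with open(os.path.join(label_dir, name)) as f:
            for line_no, line in enumerate(f, start=1):
                parts = line.split()
                if not parts:
                    continue  # skip blank lines
                class_id = int(parts[0])
                counts[class_id] += 1
                if not 0 <= class_id < num_classes:
                    out_of_range.append((name, line_no, class_id))
    return counts, out_of_range


label_dir = "dota/labels/train"  # path from the directory layout in step 2
if os.path.isdir(label_dir):
    counts, bad = count_label_classes(label_dir, num_classes=14)
    for class_id in sorted(counts):
        print(f"class {class_id}: {counts[class_id]} instances")
    for name, line_no, class_id in bad:
        print(f"{name}:{line_no} has class id {class_id}, outside 0..13")
```

A class that never appears, or an id of 14 or higher, usually points to a mismatch between the label conversion and the names list.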

4. Train the Model

To train YOLOv8 on the detection task, run the following command:

yolo detect train data=dota.yaml model=yolov8n.pt epochs=100 imgsz=512

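Explanation:

  • data=dota.yaml: the dataset configuration file.
  • model=yolov8n.pt: start from the pretrained YOLOv8 nano detection model (yolov8n). Larger variants are also available: yolov8s, yolov8m, yolov8l, and yolov8x.
  • epochs=100: the number of training epochs.
  • imgsz=512: the input image size in pixels used for training.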
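
To get a feel for what imgsz=512 means on DOTA-sized images, the arithmetic below estimates how large an object remains after resizing. It is a rough sketch that assumes the long side of each image is scaled to imgsz before training; the helper name scaled_size and the example object sizes are illustrative.

```python
def scaled_size(obj_px, image_long_side, imgsz):
    """Approximate object size in pixels after the image's long side is resized to imgsz."""
    return obj_px * imgsz / image_long_side


# Objects of 10, 20, and 50 px in a 4000 x 4000 image resized to imgsz=512
for obj_px in (10, 20, 50):
    print(f"{obj_px} px -> {scaled_size(obj_px, 4000, 512):.2f} px")
# 10 px -> 1.28 px
# 20 px -> 2.56 px
# 50 px -> 6.40 px
```

If the objects you care about shrink to only a few pixels, a larger imgsz or cropping the images into smaller tiles before training are common options to consider.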

5. Evaluate the Model

After training finishes, use the following command to evaluate the model on the validation set:

yolo detect val data=dota.yaml model=runs/detect/train/weights/best.pt imgsz=512

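Explanation:

  • data=dota.yaml: the dataset configuration file.
  • model=runs/detect/train/weights/best.pt: the best weights saved during training. If you have trained more than once, the run directory may be named train2, train3, and so on.
  • imgsz=512: the input image size in pixels; use the same value as in training.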
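
The validation run reports precision, recall, mAP50, and mAP50-95. All of these depend on the intersection over union (IoU) between predicted and ground-truth boxes; for mAP50, a prediction counts as a match when its IoU with a ground-truth box of the same class is at least 0.5. The sketch below computes IoU for two boxes in (x1, y1, x2, y2) form. The function name box_iou is illustrative and this is not the library's own implementation.

```python
def box_iou(a, b):
    """IoU of two axis-aligned boxes given as (x1, y1, x2, y2)."""
    # Corners of the intersection rectangle
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


# Two 10 x 10 boxes that overlap by half: intersection 50, union 150
print(round(box_iou((0, 0, 10, 10), (5, 0, 15, 10)), 4))  # -> 0.3333
```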

6. Visualize the Predictions

Use the following Python code to visualize the model's predictions:

import cv2
from ultralytics import YOLO

# Load the trained model
model = YOLO('runs/detect/train/weights/best.pt')

# Read the image (OpenCV returns None if the file cannot be read)
image_path = 'dota/images/val/P0001.png'
image = cv2.imread(image_path)
if image is None:
    raise FileNotFoundError(f'Could not read image: {image_path}')

# Run prediction
results = model(image)

# Draw the predictions on the image
for result in results:
    boxes = result.boxes.xyxy
    confidences = result.boxes.conf
    class_ids = result.boxes.cls

    for box, conf, class_id in zip(boxes, confidences, class_ids):
        x1, y1, x2, y2 = map(int, box.tolist())
        label = model.names[int(class_id)]
        confidence = float(conf)

        # Draw the bounding box and its label
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        text = f'{label}: {confidence:.2f}'
        cv2.putText(image, text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

# Show the image (requires a desktop session; press any key to close the window)
cv2.imshow('Detection Prediction', image)
cv2.waitKey(0)
cv2.destroyAllWindows()

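7. Five Commonly Used Scripts

  1. Data augmentation script: generates additional training data.
  2. Data check script: checks the dataset for completeness and consistency.
  3. Model evaluation script: evaluates the model with COCO-style metrics.
  4. Prediction saving script: saves the predictions to a file.
  5. Model inference script: runs inference on new images.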
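
7.1 Data augmentation script

import os
import cv2
from imgaug import augmenters as iaa
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage

def load_yolo_boxes(label_path, img_w, img_h):
    # Read YOLO labels (class cx cy w h, normalized) as pixel-space boxes
    boxes = []
    if os.path.exists(label_path):
        with open(label_path) as f:
            for line in f:
                parts = line.split()
                if len(parts) != 5:
                    continue
                cx, cy, w, h = (float(v) for v in parts[1:])
                boxes.append(BoundingBox(
                    x1=(cx - w / 2) * img_w, y1=(cy - h / 2) * img_h,
                    x2=(cx + w / 2) * img_w, y2=(cy + h / 2) * img_h,
                    label=int(parts[0])))
    return BoundingBoxesOnImage(boxes, shape=(img_h, img_w))

def augment_image(image_path, label_path, out_image_dir, out_label_dir, num_augmentations=5):
    image = cv2.imread(image_path)
    if image is None:
        print(f"Could not read {image_path}, skipping.")
        return
    img_h, img_w = image.shape[:2]
    boxes = load_yolo_boxes(label_path, img_w, img_h)

    seq = iaa.Sequential([
        iaa.Fliplr(0.5),  # horizontal flip with probability 0.5
        iaa.Affine(rotate=(-10, 10)),  # random rotation between -10 and 10 degrees
        iaa.Multiply((0.8, 1.2))  # brightness change
    ])

    stem = os.path.splitext(os.path.basename(image_path))[0]
    for i in range(num_augmentations):
        # Transform the image and its boxes together so the labels stay valid
        image_aug, boxes_aug = seq(image=image, bounding_boxes=boxes)
        boxes_aug = boxes_aug.remove_out_of_image().clip_out_of_image()
        cv2.imwrite(os.path.join(out_image_dir, f"{stem}_aug{i}.png"), image_aug)
        with open(os.path.join(out_label_dir, f"{stem}_aug{i}.txt"), "w") as f:
            for bb in boxes_aug.bounding_boxes:
                f.write(f"{bb.label} {bb.center_x / img_w:.6f} {bb.center_y / img_h:.6f} "
                        f"{bb.width / img_w:.6f} {bb.height / img_h:.6f}\n")

# Dataset paths
image_dir = 'dota/images/train'
label_dir = 'dota/labels/train'
out_image_dir = 'dota/augmented/images/train'
out_label_dir = 'dota/augmented/labels/train'
os.makedirs(out_image_dir, exist_ok=True)
os.makedirs(out_label_dir, exist_ok=True)

# Augment every PNG image in the training set
for name in os.listdir(image_dir):
    if name.endswith('.png'):
        label_path = os.path.join(label_dir, os.path.splitext(name)[0] + '.txt')
        augment_image(os.path.join(image_dir, name), label_path, out_image_dir, out_label_dir)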

7.2 Data check script

import os

def check_data_consistency(image_dir, label_dir):
    # Compare base names (without extension) so unrelated files do not cause false alarms
    image_stems = {os.path.splitext(f)[0] for f in os.listdir(image_dir) if f.lower().endswith('.png')}
    label_stems = {os.path.splitext(f)[0] for f in os.listdir(label_dir) if f.lower().endswith('.txt')}

    missing_labels = sorted(image_stems - label_stems)
    missing_images = sorted(label_stems - image_stems)

    if missing_labels:
        # Images without a label file are treated as background (no objects) during training
        print("Missing labels for images:")
        for stem in missing_labels:
            print(stem + '.png')

    if missing_images:
        print("Missing images for labels:")
        for stem in missing_images:
            print(stem + '.txt')

    if not missing_labels and not missing_images:
        print("Data consistency check passed.")

# Dataset paths
image_dir = 'dota/images/train'
label_dir = 'dota/labels/train'

check_data_consistency(image_dir, label_dir)

7.3 Model evaluation script

import json
import os

import cv2
from ultralytics import YOLO
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

def evaluate_model(model_path, data_config, ann_file):
    model = YOLO(model_path)
    coco = COCO(ann_file)
    results = []

    for img_id in coco.getImgIds():
        img_info = coco.loadImgs(img_id)[0]
        # file_name in the annotation file is assumed to be relative to the dataset root
        image_path = os.path.join(data_config['path'], img_info['file_name'])
        image = cv2.imread(image_path)
        if image is None:
            print(f"Could not read {image_path}, skipping.")
            continue

        # A low confidence threshold keeps low-scoring boxes, which mAP needs
        predictions = model(image, conf=0.001)

        for pred in predictions:
            boxes = pred.boxes.xyxy
            confidences = pred.boxes.conf
            class_ids = pred.boxes.cls

            for box, conf, class_id in zip(boxes, confidences, class_ids):
                x1, y1, x2, y2 = box.tolist()
                results.append({
                    "image_id": img_id,
                    # Assumes COCO category ids are 1-based and follow the YOLO class order
                    "category_id": int(class_id) + 1,
                    "bbox": [x1, y1, x2 - x1, y2 - y1],  # COCO boxes are [x, y, width, height]
                    "score": float(conf)
                })

    if not results:
        print("No predictions were produced; nothing to evaluate.")
        return

    with open('results.json', 'w') as f:
        json.dump(results, f)

    coco_dt = coco.loadRes('results.json')
    coco_eval = COCOeval(coco, coco_dt, 'bbox')
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

# Model path and data configuration
model_path = 'runs/detect/train/weights/best.pt'
data_config = {
    'path': './dota',
    'annotations': 'annotations/instances_val.json'
}
ann_file = os.path.join(data_config['path'], data_config['annotations'])

evaluate_model(model_path, data_config, ann_file)
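
The evaluation script expects a COCO-style annotation file (annotations/instances_val.json), which the directory layout in step 2 does not contain. One way to produce it from the YOLO label files is sketched below. The function name yolo_to_coco and the image_sizes argument are assumptions made for this illustration; in practice you would fill image_sizes by reading each image's width and height. Category ids are written 1-based to match the int(class_id) + 1 convention used in the evaluation script.

```python
import json
import os


def yolo_to_coco(label_dir, image_sizes, class_names):
    """Build a COCO-style ground-truth dict from YOLO label files.

    image_sizes maps an image file name (relative to the dataset root) to
    (width, height) in pixels. It is a hypothetical input for this sketch.
    """
    coco = {
        "images": [],
        "annotations": [],
        "categories": [{"id": i + 1, "name": n} for i, n in enumerate(class_names)],
    }
    ann_id = 1
    for img_id, (file_name, (w, h)) in enumerate(sorted(image_sizes.items()), start=1):
        coco["images"].append({"id": img_id, "file_name": file_name, "width": w, "height": h})
        stem = os.path.splitext(os.path.basename(file_name))[0]
        label_path = os.path.join(label_dir, stem + ".txt")
        if not os.path.exists(label_path):
            continue  # image without objects
        with open(label_path) as f:
            for line in f:
                parts = line.split()
                if len(parts) != 5:
                    continue
                cx, cy, bw, bh = (float(v) for v in parts[1:])
                box_w, box_h = bw * w, bh * h
                # COCO boxes are [x_min, y_min, width, height] in pixels
                x, y = cx * w - box_w / 2, cy * h - box_h / 2
                coco["annotations"].append({
                    "id": ann_id, "image_id": img_id,
                    "category_id": int(parts[0]) + 1,
                    "bbox": [x, y, box_w, box_h],
                    "area": box_w * box_h, "iscrowd": 0,
                })
                ann_id += 1
    return coco


def save_coco(coco, out_path):
    """Write the annotation dict to a JSON file, creating the directory if needed."""
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(coco, f)
```

For example, yolo_to_coco('dota/labels/val', {'images/val/P0001.png': (4000, 4000)}, class_names) followed by save_coco(coco, 'dota/annotations/instances_val.json') would yield a file whose file_name entries resolve against the dataset root, as the evaluation script assumes. The image size shown is only a placeholder.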

7.4 Prediction saving script

import json
import os

import cv2
from ultralytics import YOLO

def save_predictions(model_path, image_dir, output_dir):
    model = YOLO(model_path)
    os.makedirs(output_dir, exist_ok=True)

    image_files = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')]

    results = []
    for image_file in image_files:
        image = cv2.imread(image_file)
        if image is None:
            print(f"Could not read {image_file}, skipping.")
            continue
        predictions = model(image)

        for pred in predictions:
            boxes = pred.boxes.xyxy
            confidences = pred.boxes.conf
            class_ids = pred.boxes.cls

            for box, conf, class_id in zip(boxes, confidences, class_ids):
                x1, y1, x2, y2 = map(int, box.tolist())
                w, h = x2 - x1, y2 - y1
                results.append({
                    "image_path": image_file,
                    "category_id": int(class_id),  # 0-based YOLO class index
                    "bbox": [x1, y1, w, h],  # [x, y, width, height] in pixels
                    "score": float(conf)
                })

    # Write all predictions to a single JSON file
    with open(os.path.join(output_dir, 'predictions.json'), 'w') as f:
        json.dump(results, f)

# Model path and dataset paths
model_path = 'runs/detect/train/weights/best.pt'
image_dir = 'dota/images/val'
output_dir = 'dota/predictions'

save_predictions(model_path, image_dir, output_dir)

7.5 Model inference script

import cv2
from ultralytics import YOLO

def infer_model(model_path, image_path):
    model = YOLO(model_path)
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f'Could not read image: {image_path}')

    results = model(image)

    for result in results:
        boxes = result.boxes.xyxy
        confidences = result.boxes.conf
        class_ids = result.boxes.cls

        for box, conf, class_id in zip(boxes, confidences, class_ids):
            x1, y1, x2, y2 = map(int, box.tolist())
            label = model.names[int(class_id)]
            confidence = float(conf)

            # Draw the bounding box and its label
            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
            text = f'{label}: {confidence:.2f}'
            cv2.putText(image, text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    # Show the image (press any key to close the window)
    cv2.imshow('Inference', image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

# Model path and image path
model_path = 'runs/detect/train/weights/best.pt'
image_path = 'dota/images/val/P0001.png'

infer_model(model_path, image_path)

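8. Summary

The steps above provide a complete framework for training a YOLOv8 object detector on the DOTA dataset, with code for data preparation, model training, evaluation, and result visualization.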
