mmdetection实战,训练自己的数据集

1 库安装

        pip install timm==1.0.7 thop efficientnet_pytorch==0.7.1 einops grad-cam==1.4.8 dill==0.3.6 albumentations==1.4.11 pytorch_wavelets==1.3.0 tidecv PyWavelets -i https://pypi.tuna.tsinghua.edu.cn/simple
        pip install -U openmim -i https://pypi.tuna.tsinghua.edu.cn/simple
        mim install mmengine -i https://pypi.tuna.tsinghua.edu.cn/simple
        mim install "mmcv==2.1.0" -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install YOLO
pip install ultralytics

*需要注意mmcv库的兼容性。

2 源码下载

GitHub - open-mmlab/mmdetection at v3.0.0

3 数据集准备

我是yolo数据集转coco

使用时需要修改类别和路径。

import json
import os
import shutil

import cv2

# info ,license,categories 结构初始化;
# 在train.json,val.json,test.json里面信息是一致的;

# info,license暂时用不到
info = {
    "year": 2024,
    "version": '1.0',
    "date_created": 2024 - 6 - 9
}

licenses = {
    "id": 1,
    "name": "null",
    "url": "null",
}

# 自己的标签类别,跟yolo的数据集类别要对应好;
categories = [
    {
        "id": 0,
        "name": 'L',
        "supercategory": 'lines',
    },
    {
        "id": 1,
        "name": 'R',
        "supercategory": 'lines',
    },
    {
        "id": 2,
        "name": 'I',
        "supercategory": 'lines',
    },
    {
        "id": 3,
        "name": 'M',
        "supercategory": 'lines',
    },
    {
        "id": 4,
        "name": 'A',
        "supercategory": 'lines',
    }

]

# 初始化train,test、valid 数据字典
# info licenses categories 在train和test里面都是一致的;
train_data = {'info': info, 'licenses': licenses, 'categories': categories, 'images': [], 'annotations': []}
test_data = {'info': info, 'licenses': licenses, 'categories': categories, 'images': [], 'annotations': []}
valid_data = {'info': info, 'licenses': licenses, 'categories': categories, 'images': [], 'annotations': []}


# image_path 对应yolov8的图像路径,比如images/train;
# label_path 对应yolov8的label路径,比如labels/train 跟images要对应;
def yolo_covert_coco_format(image_path, label_path):
    images = []
    annotations = []
    for index, img_file in enumerate(os.listdir(image_path)):
        if img_file.endswith('.jpg'):
            image_info = {}
            img = cv2.imread(os.path.join(image_path, img_file))
            height, width, channel = img.shape
            image_info['id'] = index
            image_info['file_name'] = img_file
            image_info['width'], image_info['height'] = width, height
        else:
            continue
        if image_info != {}:
            images.append(image_info)
        # 处理label信息-------
        label_file = os.path.join(label_path, img_file.replace('.jpg', '.txt'))
        with open(label_file, 'r') as f:
            for idx, line in enumerate(f.readlines()):
                info_annotation = {}
                class_num, xs, ys, ws, hs = line.strip().split(' ')
                class_id, xc, yc, w, h = int(class_num), float(xs), float(ys), float(ws), float(hs)
                xmin = (xc - w / 2) * width
                ymin = (yc - h / 2) * height
                xmax = (xc + w / 2) * width
                ymax = (yc + h / 2) * height
                bbox_w = int(width * w)
                bbox_h = int(height * h)
                img_copy = img[int(ymin):int(ymax), int(xmin):int(xmax)].copy()

                info_annotation["category_id"] = class_id  # 类别的id
                info_annotation['bbox'] = [xmin, ymin, bbox_w, bbox_h]  ## bbox的坐标
                info_annotation['area'] = bbox_h * bbox_w  ###area
                info_annotation['image_id'] = index  # bbox的id
                info_annotation['id'] = index * 100 + idx  # bbox的id
                # cv2.imwrite(f"./temp/{info_annotation['id']}.jpg", img_copy)
                info_annotation['segmentation'] = [[xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]]  # 四个点的坐标
                info_annotation['iscrowd'] = 0  # 单例
                annotations.append(info_annotation)
    return images, annotations


# key == train,test,val
# 对应要生成的json文件,比如instances_train.json,instances_test.json,instances_val.json
# 只是为了不重复写代码。。。。。
def gen_json_file(yolov8_data_path, coco_format_path, key):
    # json path
    json_path = os.path.join(coco_format_path, f'annotations/instances_{key}.json')
    dst_path = os.path.join(coco_format_path, f'{key}')
    if not os.path.exists(os.path.dirname(json_path)):
        os.makedirs(os.path.dirname(json_path), exist_ok=True)
    data_path = os.path.join(yolov8_data_path, f'images/{key}')
    label_path = os.path.join(yolov8_data_path, f'labels/{key}')
    images, anns = yolo_covert_coco_format(data_path, label_path)
    if key == 'train':
        train_data['images'] = images
        train_data['annotations'] = anns
        with open(json_path, 'w') as f:
            json.dump(train_data, f, indent=2)
        # shutil.copy(data_path,'')
    elif key == 'test':
        test_data['images'] = images
        test_data['annotations'] = anns
        with open(json_path, 'w') as f:
            json.dump(test_data, f, indent=2)
    elif key == 'val':
        valid_data['images'] = images
        valid_data['annotations'] = anns
        with open(json_path, 'w') as f:
            json.dump(valid_data, f, indent=2)
    else:
        print(f'key is {key}')
    print(f'generate {key} json success!')
    return


if __name__ == '__main__':
    yolov8_data_path = r''
    coco_format_path = r''
    gen_json_file(yolov8_data_path, coco_format_path, key='train')
    gen_json_file(yolov8_data_path, coco_format_path, key='val')
    gen_json_file(yolov8_data_path, coco_format_path, key='test')

coco标签可视化工具,可以用于检查数据集是否有问题


import matplotlib.pyplot as plt
import matplotlib.patches as patches
from pycocotools.coco import COCO
import numpy as np
import skimage.io as io
import os
# 定义COCO数据集的路径
dataDir = r''
dataType = 'train2017'
annFile = f'{dataDir}/annotations/instances_{dataType}.json'

# 初始化COCO API
coco = COCO(annFile)

# 获取一张图片的ID
imgIds = coco.getImgIds()
img = coco.loadImgs(imgIds[np.random.randint(0, len(imgIds))])[0]

# 加载并显示图片
I = io.imread(f'{dataDir}/{dataType}/{img["file_name"]}')
plt.imshow(I)
plt.axis('off')

# 获取图片中的标注
annIds = coco.getAnnIds(imgIds=img['id'], iscrowd=None)
anns = coco.loadAnns(annIds)

# 显示标注和类别标签
for ann in anns:
    bbox = ann['bbox']
    category_id = ann['category_id']
    category_name = coco.loadCats(category_id)[0]['name']  # 获取类别名称

    # 画出边框
    rect = patches.Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], linewidth=2, edgecolor='r', facecolor='none')
    plt.gca().add_patch(rect)

    # 在边框上方显示类别名称
    plt.text(bbox[0], bbox[1] - 10, category_name, color='yellow', fontsize=12, weight='bold', backgroundcolor='black')

# 确保输出目录存在

output_path = f'./{img["file_name"]}_annotated.jpg'
plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
plt.show()

4 代码修改

命令行进入代码文件夹,然后输入指令

pip install -v -e.

根据你想要跑什么模型,以faster rcnn为例。

你需要修改faster-rcnn_r50_fpn_1x.py ,classes_name.py ,coco.py, coco_detection.py 和shedule_1x.py中的数据集参数和训练参数。例如类别名,类别数,数据集路径,epoches和batch等。

这些文件的路径

mmdetection-3.0.0/configs/_base_/schedules/schedule_1x.py
mmdetection-3.0.0/configs/_base_/datasets/coco_detection.py
mmdetection-3.0.0/configs/_base_/models/faster-rcnn_r50_fpn.py
mmdetection-3.0.0/mmdet/datasets/coco.py

5 运行

python tools/train.py  configs/_base_/models/faster-rcnn_r50_fpn.py

或 

python tools/train.py work_dirs/faster-rcnn_r50_fpn_1x_coco/faster-rcnn_r50_fpn_1x_coco.py

后续若需要修改参数可以直接在faster-rcnn_r50_fpn_1x_coco.py中修改。

6 测试

python tools/test.py work_dirs/faster-rcnn_r50_fpn_1x_coco/faster-rcnn_r50_fpn_1x_coco.py work_dirs/faster-rcnn_r50_fpn_1x_coco/epoch_20.pth --show-dir out

***etection 3.0 训练自己的数据集,您需要执行以下步骤: 1. 数据准备:准备好您自己的数据集,并确保数据集的目录结构与 COCO 数据集相似。数据集应包含标注文件(如 JSON 或 XML)和图像文件。 2. 数据转换:将您的数据集转换为 COCO 格式。您可以使用工具如 labelme、VIA 等来标注和转换数据。 3. 配置模型:在 mmdetection 3.0 的配置文件中选择适合您任务的模型,例如 Faster R-CNN、Mask R-CNN 等。您可以在 mmdetection 的模型库中找到相关的配置文件,并根据您的需求进行修改。 4. 修改配置文件:打开所选模型的配置文件,根据您的数据集训练需求进行相应修改。主要包括类别数目、数据集路径、训练和测试的批量大小、学习率等参数。 5. 训练模型:使用命令行运行训练脚本,指定配置文件和 GPU 数量。例如,使用以下命令启动训练: ```shell python tools/train.py <CONFIG_FILE> --gpus <NUM_GPUS> ``` 其中 `<CONFIG_FILE>` 是您修改后的配置文件路径,`<NUM_GPUS>` 是用于训练的 GPU 数量。 6. 测试模型:在训练过程中,您可以使用验证集来监控模型的性能。当训练完成后,您可以使用以下命令进行模型测试: ```shell python tools/test.py <CONFIG_FILE> <CHECKPOINT_FILE> --eval <EVAL_METRICS> ``` 其中 `<CONFIG_FILE>` 是您修改后的配置文件路径,`<CHECKPOINT_FILE>` 是训练过程中保存的模型权重文件路径,`<EVAL_METRICS>` 是评估指标,如 bbox、segm 等。 7. 推理模型:使用训练好的模型对新的图像进行目标检测。您可以使用以下命令进行推理: ```shell python tools/infer.py <CONFIG_FILE> <CHECKPOINT_FILE> <IMAGE_FILE> --show ``` 其中 `<CONFIG_FILE>` 是您修改后的配置文件路径,`<CHECKPOINT_FILE>` 是训练过程中保存的模型权重文件路径,`<IMAGE_FILE>` 是要进行推理的图像文件路径。 以上是 mmdetection 3.0 训练自己数据集的基本步骤。在实际操作中,您可能需要根据您的具体需求进行进一步的调试和优化。希望对您有所帮助!如有其他问题,请随时提问。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值