【语义分割】使用MMSegmentation训练VOC格式自定义数据集

1.自定义数据集(VOC格式)的准备

1.1. 将数据集按照以下路径存放到MMSeg目录
mmsegmentation
|--data
|  |--VOCdevkit
|     |--VOC2012
|        |--Annotations
|           |-- 01.xml
|           |-- 02.xml
|
|        |--ImageSets
|           |-- Main
|           |-- Segmentation
|
|        |-- JPEGImages
|           |-- 01.jpg
|           |-- 02.jpg
|
|        |-- SegmentationClass
|           |-- 01.png
|           |-- 02.png|
1.2. 选择模型和预训练权重
1.3. 修改数据集配置文件

查看所选模型的配置文件,以训练DeepLabV3+为例:

(1) 修改数据集信息为:

'../_base_/datasets/pascal_voc12.py'

(2) 进入模型根目录文件修改classes数:configs/_base_/models/deeplabv3_r50-d8.py

(3) 进入voc数据集配置文件修改数据集类别信息:

(4) 进入训练参数配置文件修改预训练模型:configs/_base_/default_runtime.py

1.4. 每次修改完config文件需进行setup
python setup.py install

2. 模型训练

python tools/train.py config.py

以训练DeepLabV3+为例:

python tools/train.py configs/deeplabv3/deeplabv3_r50-d8_4xb4-40k_voc12aug-512x512.py

3. 模型图片批量推理

import os
import torch
import cv2
import argparse
import numpy as np
from pprint import pprint
from tqdm import tqdm
from mmseg.apis import init_model, inference_model

DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
IMAGE_FILE_PATH = r"data/VOCdevkit/VOC2012/JPEGImages"
CONFIG = r'work_dirs/deeplabv3_r50-d8_4xb4-40k_voc12aug-512x512/deeplabv3_r50-d8_4xb4-40k_voc12aug-512x512.py'
CHECKPOINT = r'work_dirs/deeplabv3_r50-d8_4xb4-40k_voc12aug-512x512/iter_40000.pth'
SAVE_DIR = r"work_dir\infer_results"


def parse_args():
    parser = argparse.ArgumentParser(description='Visualize CAM')
    parser.add_argument('--img', default=IMAGE_FILE_PATH, help='Image file')
    parser.add_argument('--config', default=CONFIG, help='Config file')
    parser.add_argument('--checkpoint', default=CHECKPOINT, help='Checkpoint file')
    parser.add_argument('--device', default=DEVICE, help='device')
    parser.add_argument('--save_dir', default=SAVE_DIR, help='save_dir')

    args = parser.parse_args()
    return args


def make_full_path(root_list, root_path):
    file_full_path_list = []
    for filename in root_list:
        file_full_path = os.path.join(root_path, filename)
        file_full_path_list.append(file_full_path)
    return file_full_path_list


def read_filepath(root):
    from natsort import natsorted
    test_image_list = natsorted(os.listdir(root))
    test_image_full_path_list = make_full_path(test_image_list, root)
    return test_image_full_path_list


def apply_colormap(pred_mask):
    color_map = {
        0: [0, 0, 0],  # black: background
        1: [0, 255, 255],  # yellow: careful
        2: [255, 0, 0],  # blue: passable
        3: [0, 0, 255]  # red: impassable
    }

    rgb_image = np.zeros((*pred_mask.shape, 3), dtype=np.uint8)

    for key in color_map:
        rgb_image[pred_mask == key] = color_map[key]

    return rgb_image


def main():
    args = parse_args()

    model_mmseg = init_model(args.config, args.checkpoint, device=args.device)

    for imgs in tqdm(read_filepath(args.img)):
        result = inference_model(model_mmseg, imgs)
        pred_mask = result.pred_sem_seg.data.squeeze(0).detach().cpu().numpy().astype(np.uint8)

        colored_mask = apply_colormap(pred_mask)

        save_path = os.path.join(args.save_dir, f"{os.path.basename(args.config).split('.')[0]}")
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        cv2.imwrite(os.path.join(save_path, f"{os.path.basename(result.img_path).split('.')[0]}.png"), colored_mask,
                    [cv2.IMWRITE_PNG_COMPRESSION, 0])


if __name__ == '__main__':
    main()

MMLab相关指令:

【目标检测】MMDetection常用指令_model.show_result_ericdiii的博客-CSDN博客文章浏览阅读1.6k次,点赞2次,收藏5次。目标检测MMDetection库常用指令_model.show_resulthttps://blog.csdn.net/ericdiii/article/details/125722277?spm=1001.2014.3001.5502【目标检测】使用MMDetection训练自定义COCO格式数据集_mmdetection训练自己的coco数据集-CSDN博客文章浏览阅读1.9k次,点赞3次,收藏19次。使用MMDetection训练自定义COCO格式数据集_mmdetection训练自己的coco数据集https://blog.csdn.net/ericdiii/article/details/126812948?spm=1001.2014.3001.5501

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值