MMSegmentation自定义数据集

📕前言

该文章主要是简述一下自己为了完成极市平台赛事过程中,使用 MMSegmentation 语义分割开源库的心得。

在学习一个新的工具之前,一定需要明白自己是用工具实现什么目标,而不是为了学工具而学,一旦有了目的会给你所作的事情带来意义,但是也要避免急于求成(人总是喜欢简单直接的事情,但是只有真正拉扯过肌肉才会成长),所以坚持不下去的时候,只要明白这是你的大脑退缩了,但你仍然想学。💪

\quad

🌳文章结构

本文章将从一下几个方面介绍如何上手 MMsegmentation,并用 MMDeploy 实现简单的部署:

  1. 安装 MMSegmentation
  2. MMSegmentation 的文件结构
  3. MMSegmentation 的配置文件(核心)
  4. 如何在 MMSegmentation 中自定义数据集
  5. 训练和测试

我强烈建议配合官方文档一起学习:https://mmsegmentation.readthedocs.io/zh_CN/latest/index.html
PS:如此良心的开源库还带中文文档!😭

\quad

📝正文

安装 MMSegmentation

环境准备(可选,但推荐)

一般我们为了环境隔离用 Miniconda(Anaconda) 创建一个新的 python 环境,但在某些情况下也可以不用,取决于你的习惯。

官方网站下载并安装 Miniconda & 创建一个 conda 环境,并激活:

conda create --name openmmlab python=3.8 -y
conda activate openmmlab

\quad

安装库
  1. 根据官网安装 pytorch,现在更新到2.0了,但是推荐安装之前的版本(可以点击页面中下面红框的链接,授之以渔),也可以直接点击 install previous versions of PyTorch(授之以鱼)

    gpu 版本(要对应自己的 cuda 版本,pip和conda 二选一)

# pip 安装
# CUDA 11.1 
pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html

# 或者
# conda 安装
# CUDA 11.3
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge 

cpu 版本(看MMSegmentation的官方文档吧)

![image.png](/img/bVc7tU5)

\quad

  1. 安装 MMCV(OpenMMLab 其他许多库都有这个依赖)
    推荐安装方式 mim,更多方式看 MMCV
pip install -U openmim
mim install mmengine
mim install "mmcv>=2.0.0"

\quad

  1. 安装 MMsegmentation
    a. 方式一:源码安装,这个比较容易后期开发,因为能够直接修改并使用源码(本教程安装方式)

    git clone -b main https://github.com/open-mmlab/mmsegmentation.git
    cd mmsegmentation
    pip install -v -e .
    # '-v' 表示详细模式,更多的输出
    # '-e' 表示以可编辑模式安装工程,
    # 因此对代码所做的任何修改都生效,无需重新安装
    

    b. 方式二:作为依赖库安装

    pip install "mmsegmentation>=1.0.0"
    

\quad

验证安装是否成功

源码安装检验方式

cd mmsegmentation
python demo/image_demo.py demo/demo.png \\
configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py \\
pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth \\
--device cuda:0 --out-file result.jpg

您将在当前文件夹中看到一个新图像 result.jpg,其中所有目标都覆盖了分割 mask

其他更多安装方式见官方文档:https://mmsegmentation.readthedocs.io/zh_CN/latest/get_started.html#mmseg

\quad

MMSegmentation 的文件结构

接下来我们稍微看一下 MMsegmentation 的文件结构目录

mmsegmentation
- configs # **配置文件,是该库的核心**
	- _base_ # 基础模块文件,**但本质上还是配置文件**,包括数据集,模型,训练配置
		- datasets
		- models
		- schedules	
	- else model config # 除了 _base_ 之外,其他都是通过利用 _base_ 中定义好的模块进行组合的模型文件


- mmseg # **这是库核心的实现,上面配置文件的模块都在这里定义**
	- datasets
	- models

- tools # 这里包括训练、测试、转onnx等写好了的工具,直接调用即可
	- train.py
	- test.py

- data # 放置数据集

- demo # 提供了几个小 demo(可不管)
- docker # 容器配置(可不管)
- docs # 各种说明文档(可不管)

- projects # (可不管)
- requirements # (可不管)
- tests # (可不管)

从上面可以看出,其实 MMSegmentation 做了很好的封装,如果只是使用,那是非常容易上手的。

config/base 和 mmseg 中的 datasets、models等文件有什么区别呢?
下面用 ade 数据集举一个例子(大致看一下差异,不需要弄懂):

  • config/_base_/datasets/ade20k.py
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
crop_size = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(
        type='RandomResize',
        scale=(2048, 512),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 512), keep_ratio=True),
    # add loading annotation after ``Resize`` because ground truth
    # does not need to do resize data transform
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', backend_args=None),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/training', seg_map_path='annotations/training'),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/validation',
            seg_map_path='annotations/validation'),
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator

  • mmseg/datasets/ade.py
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class ADE20KDataset(BaseSegDataset):
    """ADE20K dataset.

    In segmentation map annotation for ADE20K, 0 stands for background, which
    is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
    The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
    '.png'.
    """
    METAINFO = dict(
        classes=('wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road',
                 'bed ', 'windowpane', 'grass', 'cabinet', 'sidewalk',
                 'person', 'earth', 'door', 'table', 'mountain', 'plant',
                 'curtain', 'chair', 'car', 'water', 'painting', 'sofa',
                 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair',
                 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp',
                 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
                 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
                 'skyscraper', 'fireplace', 'refrigerator', 'grandstand',
                 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow',
                 'screen door', 'stairway', 'river', 'bridge', 'bookcase',
                 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill',
                 'bench', 'countertop', 'stove', 'palm', 'kitchen island',
                 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine',
                 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
                 'chandelier', 'awning', 'streetlight', 'booth',
                 'television receiver', 'airplane', 'dirt track', 'apparel',
                 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle',
                 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain',
                 'conveyer belt', 'canopy', 'washer', 'plaything',
                 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall',
                 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food',
                 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal',
                 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket',
                 'sculpture', 'hood', 'sconce', 'vase', 'traffic light',
                 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate',
                 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
                 'clock', 'flag'),
        palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
                 [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
                 [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
                 [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
                 [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
                 [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
                 [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
                 [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
                 [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
                 [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
                 [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
                 [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
                 [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
                 [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
                 [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
                 [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
                 [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
                 [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
                 [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
                 [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
                 [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
                 [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
                 [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
                 [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
                 [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
                 [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
                 [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
                 [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
                 [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
                 [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
                 [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
                 [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
                 [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
                 [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
                 [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
                 [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
                 [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
                 [102, 255, 0], [92, 0, 255]])

    def __init__(self,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 reduce_zero_label=True,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)

\quad

MMSegmentation 的 config 配置文件 (核心)

在使用 MMSegmentation 中的模型进行训练和测试的时候就能够看出 config 配置文件的重要性

在单GPU上训练和测试
在单GPU上训练

tools/train.py 文件提供了在单GPU上部署训练任务的方法。

基础用法如下:

python tools/train.py  ${配置文件} [可选参数]
# 关键参数:
#	config.py # 必须提供撇脂文件
# 	--work-dir ${工作路径} # 重新指定工作路径

更多其他参数详情

举例 pspnet

python tools/train.py \\
configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py \\
--work-dir logs/pspnet

configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
该配置文件调用了_base_中定义的 models、dataset、schedules等配置文件,这种模块化方式就很容易通过重新组合来调整整体模型。

_base_ = [
    '../_base_/models/pspnet_r50-d8.py', '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
]
crop_size = (512, 1024)
data_preprocessor = dict(size=crop_size)
model = dict(data_preprocessor=data_preprocessor)

其中每个模块的配置文件细节见:https://mmsegmentation.readthedocs.io/zh_CN/latest/user_guides/1_config.html#pspnet

\quad

如何在 MMSegmentation 中自定义数据集

这应该是大家比较关心的部分,重点是。我们首先看看官方对于一些常用的数据集的文件目录是怎么样的(拿 CHASE_DB1 数据集(二类别语义分割)举个例子):

mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│   ├── CHASE_DB1
│   │   ├── images
│   │   │   ├── training
│   │   │   ├── validation
│   │   ├── annotations
│   │   │   ├── training
│   │   │   ├── validation

可见其中包含:

  • annotations:语义分割的真实 mark label
  • images:待分割的RGB图像
自定义数据集

根据以上结构我们可以构建自己的数据集,这里我主要是利用极市平台写字楼消防门堵塞识别二类别语义分割任务的数据集,其中门的label是1,背景label是0
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4YwojjrW-1682734064668)(/img/bVc7t5o)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-poPjBL2A-1682734064669)(/img/bVc7t5p)]

并且将其划分为训练集和验证集,在 mmsegmentation/data 中添加以下文件:

mmsegmentation
|   data
|   | xiaofang
│   │   ├── images
│   │   │   ├── training
│   │   │   ├── validation
│   │   ├── annotations
│   │   │   ├── training
│   │   │   ├── validation
添加数据集模块
  1. mmsegmentation/mmseg/datasets 中添加一个 xiaofang.py 定义自己的数据类 XiaoFangDataset
    xiaofang.py
    # Copyright (c) OpenMMLab. All rights reserved.
    
    from .builder import DATASETS
    from .custom import CustomDataset
    
    
    @DATASETS.register_module()
    class XiaoFangDataset(CustomDataset):
        CLASSES = ('background', 'door')
    
        PALETTE = [[120, 120, 120], [6, 230, 230]]
    
        def __init__(self, **kwargs):
            super(XiaoFangDataset, self).__init__(
                img_suffix='.jpg', # 注意路径
                seg_map_suffix='.png',
                reduce_zero_label=False,
                **kwargs)
            assert self.file_client.exists(self.img_dir)
    
    
  2. mmsegmentation/mmseg/datasets/__init__.py 中声明自己定义的数据类XiaoFangDataset
    # Copyright (c) OpenMMLab. All rights reserved.
    from .ade import ADE20KDataset
    from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
    from .chase_db1 import ChaseDB1Dataset
    from .cityscapes import CityscapesDataset
    from .coco_stuff import COCOStuffDataset
    from .custom import CustomDataset
    from .dark_zurich import DarkZurichDataset
    from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
                                   RepeatDataset)
    from .drive import DRIVEDataset
    from .face import FaceOccludedDataset
    from .hrf import HRFDataset
    from .isaid import iSAIDDataset
    from .isprs import ISPRSDataset
    from .loveda import LoveDADataset
    from .night_driving import NightDrivingDataset
    from .pascal_context import PascalContextDataset, PascalContextDataset59
    from .potsdam import PotsdamDataset
    from .stare import STAREDataset
    from .voc import PascalVOCDataset
    from .xiaofang import XiaoFangDataset
    
    __all__ = [
        'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
        'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
        'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
        'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
        'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
        'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset',
        'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', 'FaceOccludedDataset',
        'XiaoFangDataset'
    ]
    
    
  3. mmsegmentation/mmseg/core/evaluation/class_names.py 中声明自己的标签类别名称
    def xiaofang_classes():
        return [
            'background','door'
        ]
    
  4. mmsegmentation/configs/_base_/datasets 中添加自己数据集的配置文件 xiaofang.py
    # dataset settings
    dataset_type = 'XiaoFangDataset' # 数据类名称
    data_root = 'data/xiaofang' # 数据存放位置
    img_norm_cfg = dict(
        mean=[120.4652, 123.1624, 124.3220], std=[63.5322, 60.6218, 59.2707], to_rgb=True)
    crop_size = (512, 512)
    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(type='Resize', img_scale=(1920, 1080), ratio_range=(0.5, 2.0)),
        dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
        dict(type='RandomFlip', prob=0.5),
        dict(type='PhotoMetricDistortion'),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_semantic_seg']),
    ]
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(
            type='MultiScaleFlipAug',
            # img_scale=(2048, 512),
            img_scale=(1920, 1080),
            # img_scale=(960, 540),
            # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
            flip=False,
            transforms=[
                dict(type='Resize', keep_ratio=True),
                dict(type='RandomFlip'),
                dict(type='Normalize', **img_norm_cfg),
                dict(type='ImageToTensor', keys=['img']),
                dict(type='Collect', keys=['img']),
            ])
    ]
    data = dict(
        samples_per_gpu=4,
        workers_per_gpu=4,
        train=dict(
            type=dataset_type,
            data_root=data_root,
            img_dir='images/training',
            ann_dir='annotations/training',
            pipeline=train_pipeline),
        val=dict(
            type=dataset_type,
            data_root=data_root,
            img_dir='images/validation',
            ann_dir='annotations/validation',
            pipeline=test_pipeline),
        test=dict(
            type=dataset_type,
            data_root=data_root,
            img_dir='images/validation',
            ann_dir='annotations/validation',
            pipeline=test_pipeline))
    
    

其中配置文件参数的细节含义仍见:https://mmsegmentation.readthedocs.io/zh_CN/latest/user_guides/1_config.html#pspnet

\quad

训练和测试

在完成了数据集配置后,就需要搭建整体模型的配置文件即可,MMSegmentation 提供了许多开源模型(下面是一部分,更多详情):
在这里插入图片描述

一般需要根据自己的GPU显存大小选择模型,点击上面的 config 能够看到对应模型所需要的显存大小,如这里我们举例选择一个 STDC 模型
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-JfIEcuv1-1682734064670)(/img/bVc7t68)]

  1. 修改完整配置文件:在 mmsegmentation/configs/stdc 中添加上自己的模型 stdc2_512x1024_10k_xiaofang.py

    _base_ = ['../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py', '../_base_/datasets/xiaofang.py']
    
    # checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/stdc/stdc1_20220308-5368626c.pth'  # noqa
    
    
    norm_cfg = dict(type='BN', requires_grad=True)
    model = dict(
        type='EncoderDecoder',
        pretrained=None,
        backbone=dict(
            type='STDCContextPathNet',
            backbone_cfg=dict(
                # init_cfg=dict(type='Pretrained', checkpoint=checkpoint),
                type='STDCNet',
                stdc_type='STDCNet2',
                in_channels=3,
                channels=(32, 64, 256, 512, 1024),
                bottleneck_type='cat',
                num_convs=4,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU'),
                with_final_conv=False),
            last_in_channels=(1024, 512),
            out_channels=128,
            ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4)),
        decode_head=dict(
            type='FCNHead',
            in_channels=256,
            channels=256,
            num_convs=1,
            num_classes=2,
            in_index=3,
            concat_input=False,
            dropout_ratio=0.1,
            norm_cfg=norm_cfg,
            align_corners=False,
            sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
        auxiliary_head=[
            dict(
                type='FCNHead',
                in_channels=128,
                channels=64,
                num_convs=1,
                num_classes=2,
                in_index=2,
                norm_cfg=norm_cfg,
                concat_input=False,
                align_corners=False,
                sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
                loss_decode=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
            dict(
                type='FCNHead',
                in_channels=128,
                channels=64,
                num_convs=1,
                num_classes=2,
                in_index=1,
                norm_cfg=norm_cfg,
                concat_input=False,
                align_corners=False,
                sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
                loss_decode=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
            dict(
                type='STDCHead',
                in_channels=256,
                channels=64,
                num_convs=1,
                num_classes=2,
                boundary_threshold=0.1,
                in_index=0,
                norm_cfg=norm_cfg,
                concat_input=False,
                align_corners=False,
                loss_decode=[
                    dict(
                        type='CrossEntropyLoss',
                        loss_name='loss_ce',
                        use_sigmoid=True,
                        loss_weight=1.0),
                    dict(type='DiceLoss', loss_name='loss_dice', loss_weight=1.0)
                ]),
        ],
        # model training and testing settings
        train_cfg=dict(),
        test_cfg=dict(mode='whole'))
    
    
    checkpoint_config = dict(  # 设置检查点钩子 (checkpoint hook) 的配置文件。执行时请参考 https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py。
        by_epoch=False,
        save_last=False,  # 是否按照每个 epoch 去算 runner。
        interval=2000)  # 保存的间隔
    
    evaluation = dict(interval=1000, metric='mIoU', pre_eval=True)
    runner = dict(type='IterBasedRunner', max_iters=10000)
    log_config = dict(
        interval=10,
        hooks=[
            dict(type='TextLoggerHook', by_epoch=False),
            # dict(type='TensorboardLoggerHook')
            # dict(type='PaviLoggerHook') # for internal services
        ])
    lr_config = dict(warmup='linear', warmup_iters=1000)
    
  2. 训练

    python tools/train.py \\
    configs/stdc/stdc2_512x1024_10k_xiaofang.py \\
    --work-dir logs/stdc2
    
  3. 测试结果:MIoU=0.9225,下面分别是RGB图像、真实Label、STDC模型输出
    请添加图片描述

  • 4
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: mmsegmentation是一个基于PyTorch的开源图像分割工具箱,可以用于训练自己的数据集。以下是训练自己数据集的步骤: 1. 准备数据集:将数据集按照训练集、验证集和测试集划分,并将其转换为mmsegmentation所需的格式。 2. 配置训练参数:在mmsegmentation中,训练参数可以通过配置文件进行设置,包括模型、优化器、学习率、损失函数等。 3. 开始训练:使用mmseg的命令行工具开始训练模型,可以通过设置参数来控制训练过程。 4. 评估模型:训练完成后,可以使用mmseg的命令行工具对模型进行评估,包括计算IoU、mIoU等指标。 5. 模型预测:使用训练好的模型对新的图像进行分割预测。 需要注意的是,训练自己的数据集需要一定的计算资源和时间,同时需要对数据集进行充分的预处理和清洗,以提高模型的训练效果。 ### 回答2: mmsegmentation 是一个用于图像分割的深度学习框架,它基于 PyTorch 框架,已经被广泛应用于图像语义分割、实例分割、阴影检测等任务。其所支持的数据类型包括常用的数据集,如 PASCAL VOC、ADE20K、COCO 等。而对于我们自己的数据集,也可以通过一系列步骤来应用于 mmsegmentation 中。 首先,在准备数据时,需要将自己的数据集转化为 mmsegmentation 所支持的数据格式。具体来说,需要将数据集的图片分成训练集、验证集和测试集,同时生成一个 JSON 格式的标注文件,以供训练和测试时使用。同时,还需要对数据进行增强处理,包括大小缩放、翻转、剪裁等等。 其次,在定义模型时,需要根据自己的数据类型选择适合的模型和损失函数。这些模型和损失函数已经在 mmsegmentation 中预定义好了,同时也可以自行定义自己的模型和损失函数。例如,对于常用的图像分割任务,可以使用常见的网络模型,如 UNet、PSPNet 等。 最后,使用 mmsegmentation 进行训练和测试时,需要进行一些参数的配置。主要包括训练参数和测试参数两部分。训练参数包括训练数据集、验证数据集、批量大小、学习率、学习率策略、优化算法等等。测试参数包括测试数据集、模型路径等等。 总体而言,mmsegmentation 是一个非常灵活和易于使用的工具,我们可以使用它来训练和测试自己的数据集。同时,通过不断地调整和优化参数,我们可以得到更加准确的分割结果。 ### 回答3: mmsegmentation是一个基于PyTorch框架的图像分割工具包,可以用来实现各种图像分割算法,如FCN、U-Net、DeepLab、Mask R-CNN等。mmsegmentation提供了训练和测试的代码和模型,也支持自定义数据集的训练。 下面我们将重点介绍mmsegmentation训练自己的数据集: 1. 数据集准备 在训练之前,需要准备好一个包含训练、验证和测试图像以及它们的标注的数据集数据集应该按照一定的文件结构进行组织,比如: ``` + dataset + train - image_1.jpg - image_1.png - ... + val - image_1.jpg - image_1.png - ... + test - image_1.jpg - ... ``` 其中,“train”目录包含训练图像和它们的标注,“val”目录包含验证图像和它们的标注,“test”目录包含测试图像。图像文件可以是jpg、png等格式,标注文件可以是png、mat等格式。注意,标注文件应该和图像文件保持对应,且标注像素的取值通常为0、1、2、...、n-1,表示不同的目标类别。 2. 数据集注册 注册自己的数据集需要通过继承mmcv的Dataset类来实现。自定义数据集需要实现少量方法,包括: * \_\_init\_\_:初始化方法,包括定义类别列表、文件列表等。 * \_\_len\_\_:返回数据集中样本数量。 * \_\_getitem\_\_:返回数据集中指定下标的一条数据和标注。 需要注意的是,返回的数据应该按照mmcv的格式进行处理,比如将图像和标注分别转成ndarray格式并归一化后返回。 3. 配置模型 mmsegmentation支持的模型我们可以通过它的配置文件来配置。通过制定不同的配置文件,我们可以配置不同的网络模型、优化器、学习率策略、训练参数等。对于自己的数据集,我们需要在配置文件中指定类别数、输入图像大小等相关参数。 选择具体的网络模型需要根据自己的数据集大小选择。如果数据集较小,我们可以选择较小的模型,否则可以考虑选择较大的模型,如DeepLabV3+、FCN等。 4. 开始训练 当数据集注册和模型配置完成后,我们可以开始训练自己的数据集。可以通过mmseg中提供的工具进行训练,比如: ``` python tools/train.py ${CONFIG_FILE} ``` 其中,${CONFIG_FILE}是指定的配置文件路径。训练过程中可以通过设置检查点、学习率、优化器等参数来调整模型的训练效果。 5. 验证和测试 训练完成后,我们可以通过mmseg提供的工具进行模型验证和测试,比如: ``` # 验证 python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --eval mIoU # 测试 python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --out result.pkl ``` 其中,${CHECKPOINT_FILE}是训练过程中保存的模型检查点文件路径,验证和测试的输出结果也会保存在指定路径中。在测试阶段,我们可以查看模型的输出结果,检查预测效果是否符合预期。 以上就是使用mmsegmentation训练自己的数据集的主要步骤,需要注意的是,这只是一个大致的过程,具体操作会根据自己的数据集和需求有所不同。同时也需要在训练过程中多多尝试和调整,来达到更好的训练效果。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值