mmSegmentation: supporting multi-band image input and random window sampling

Preface

I need to use mmSegmentation for remote sensing image segmentation. The input images have 17 bands and have already been converted to npy format; the labels are png images.
Each image is roughly 512x512, and there are only about 40 of them. Because the dataset is small, I want the network to be fed randomly augmented 128x128 crops, to get the most out of data augmentation.

The config is adapted from the ISPRS Vaihingen dataset (configs/_base_/datasets/vaihingen.py). mmSegmentation's stock treatment of remote sensing imagery is simply to tile it (tools/convert_datasets/vaihingen.py): no matter how large the source is, it gets cut into 512x512 patches and saved as uint8 png.

Versions

mmcv 1.4.8
mmsegmentation 0.23.0

Supporting multi-band input

1 Modify the data loading method

The image reading method is hard-coded at mmseg/datasets/pipelines/loading.py:61, so npy files cannot be loaded:

img_bytes = self.file_client.get(filename)
img = mmcv.imfrombytes(
    img_bytes, flag=self.color_type, backend=self.imdecode_backend)

Define a new file-loading pipeline module, mmseg/datasets/pipelines/mloading.py, and register it in mmseg/datasets/pipelines/__init__.py (a sketch of the registration follows the listing).
The main change is that files are now read with mmcv.load(); a handler is registered so that mmcv.load()/mmcv.dump() understand npy files, and an in-memory cache is supported.

import os.path as osp

import mmcv
import numpy as np
from mmcv.fileio import BaseFileHandler, register_handler

from ..builder import PIPELINES


@register_handler('npy')
class NpyHandler(BaseFileHandler):
    str_like = False

    def load_from_fileobj(self, file, **kwargs):
        return np.load(file)

    # mainly supplies the default binary read mode ('rb')
    def load_from_path(self, filepath, **kwargs):
        return super(NpyHandler, self).load_from_path(
            filepath, mode='rb', **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        np.save(file, obj)

    # mainly supplies the default binary write mode ('wb')
    def dump_to_path(self, obj, filepath, **kwargs):
        super(NpyHandler, self).dump_to_path(
            obj, filepath, mode='wb', **kwargs)

    def dump_to_str(self, obj, **kwargs):
        return obj.tobytes()


@PIPELINES.register_module()
class LoadMultiResolutionImageFromFile(object):
    """Load an image from file.

    Required keys are "img_prefix" and "img_info" (a dict that must contain the
    key "filename"). Added or updated keys are "filename", "img", "img_shape",
    "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
    "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8 array.
            Defaults to False.
        color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
            Defaults to 'color'.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
        imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
            'cv2'
    """

    def __init__(self,
                 to_float32=False,
                 file_client_args=dict(backend='disk'),):
        self.to_float32 = to_float32
        self.file_client_args = file_client_args.copy()
        self.buffer = {}

    def __call__(self, results):
        """Call functions to load image and get image meta information.

        Args:
            results (dict): Result dict from :obj:`mmseg.CustomDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """

        if results.get('img_prefix') is not None:
            filename = osp.join(results['img_prefix'],
                                results['img_info']['filename'])
        else:
            filename = results['img_info']['filename']

        # in-memory cache: load each file once, then serve copies
        if self.file_client_args['backend'] == 'mem':
            if filename not in self.buffer:
                self.buffer[filename] = mmcv.load(filename)
            img = self.buffer[filename].copy()
        else:
            img = mmcv.load(filename)

        if self.to_float32:
            img = img.astype(np.float32)
        # replace nan with 0
        img[np.isnan(img)] = 0

        results['filename'] = filename
        results['ori_filename'] = results['img_info']['filename']
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        # Set initial values for default meta_keys
        results['pad_shape'] = img.shape
        results['scale_factor'] = 1.0
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        results['img_norm_cfg'] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(to_float32={self.to_float32})'
        return repr_str


@PIPELINES.register_module()
class LoadMultiResolutionAnnotations(object):
    """Load annotations for semantic segmentation.

    Args:
        reduce_zero_label (bool): Whether reduce all label value by 1.
            Usually used for datasets where 0 is background label.
            Default: False.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
        imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
            'pillow'
    """

    def __init__(self,
                 map_labels=None,
                 file_client_args=dict(backend='disk'),
                 imdecode_backend='pillow',):
        self.file_client_args = file_client_args.copy()
        self.file_client = None
        self.imdecode_backend = imdecode_backend

        self.map_labels = map_labels
        self.use_mem_buffer = False
        if self.file_client_args['backend'] == 'mem':
            self.use_mem_buffer = True
            self.buffer = {}
            self.file_client_args['backend'] = 'disk'
        if self.map_labels:
            # labels that map to exactly one class are valid; everything else
            # will be sent to the ignore index 255 in __call__
            self.valid_label = [k for k, v in self.map_labels.items() if len(v) == 1]
            self.invalid_label = [k for k in self.map_labels.keys() if k not in self.valid_label]

    def __call__(self, results):
        """Call function to load multiple types annotations.

        Args:
            results (dict): Result dict from :obj:`mmseg.CustomDataset`.

        Returns:
            dict: The dict contains loaded semantic segmentation annotations.
        """

        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        if results.get('seg_prefix', None) is not None:
            filename = osp.join(results['seg_prefix'],
                                results['ann_info']['seg_map'])
        else:
            filename = results['ann_info']['seg_map']

        # in-memory cache: decode each label map once, then serve copies
        if self.use_mem_buffer:
            if filename not in self.buffer:
                img_bytes = self.file_client.get(filename)
                self.buffer[filename] = mmcv.imfrombytes(
                    img_bytes, flag='unchanged',
                    backend=self.imdecode_backend).squeeze().astype(np.uint8)
            gt_semantic_seg = self.buffer[filename].copy()
        else:
            img_bytes = self.file_client.get(filename)
            gt_semantic_seg = mmcv.imfrombytes(
                img_bytes, flag='unchanged',
                backend=self.imdecode_backend).squeeze().astype(np.uint8)

        # modify if custom classes
        if results.get('label_map', None) is not None:
            # Add deep copy to solve bug of repeatedly
            # replace `gt_semantic_seg`, which is reported in
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in results['label_map'].items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
        # remap raw labels according to map_labels: invalid labels go to the
        # ignore index first, then valid labels are reindexed consecutively
        if self.map_labels:
            for i in self.invalid_label:
                gt_semantic_seg[gt_semantic_seg == i] = 255
            for i, j in enumerate(self.valid_label):
                gt_semantic_seg[gt_semantic_seg == j] = i
        results['gt_semantic_seg'] = gt_semantic_seg
        results['seg_fields'].append('gt_semantic_seg')
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(map_labels={self.map_labels}, '
        repr_str += f"imdecode_backend='{self.imdecode_backend}')"
        return repr_str
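
For reference, registering the new transforms amounts to importing them in mmseg/datasets/pipelines/__init__.py and extending __all__. A minimal sketch (existing entries elided):

# mmseg/datasets/pipelines/__init__.py (sketch)
from .mloading import (LoadMultiResolutionAnnotations,
                       LoadMultiResolutionImageFromFile)

__all__ = [
    # ... existing pipeline names ...
    'LoadMultiResolutionImageFromFile', 'LoadMultiResolutionAnnotations',
]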

2 Disable photometric distortion

Remove PhotoMetricDistortion from the train_pipeline of the dataset config (a sketch of the removal follows the list).

PhotoMetricDistortion applies, in order:

  1. random brightness
  2. random contrast (mode 0)
  3. conversion from BGR to HSV
  4. random saturation
  5. random hue
  6. conversion from HSV back to BGR
  7. random contrast (mode 1)

These steps assume 3-channel BGR input and are generally useless for multi-band imagery.
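
Concretely, the transform simply disappears from the pipeline. A minimal sketch (the full pipeline actually used is shown below under "Random window sampling"):

train_pipeline = [
    dict(type='LoadMultiResolutionImageFromFile'),
    dict(type='LoadMultiResolutionAnnotations'),
    # dict(type='PhotoMetricDistortion'),  # removed: BGR<->HSV jitter is undefined for 17 bands
    dict(type='Normalize', **img_norm_cfg),
    # ... remaining transforms ...
]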

3 Set the normalization parameters

Update the per-band mean and std in the normalization config, and set to_rgb to False:

img_norm_cfg = dict(
    mean=[1.06484505e+02,  7.85403773e-02,  4.19996008e-02,  3.67640592e-02,
          3.69739607e-02,  4.02020179e-02,  1.57359848e+02,  1.94448936e+00,
          1.09478876e-01, -1.86033428e-01,  9.28575218e-01,  4.71301109e-01,
          3.62822413e-01,  2.71480083e+00,  4.32536316e+01, -2.82691270e-02,
          1.88922745e+02],
    std=[5.0498714e+01, 1.3906981e-01, 7.4547209e-02, 6.1701167e-02,
         6.1981481e-02, 7.5684831e-02, 2.8444088e+01, 4.5054030e+00,
         3.6696383e-01, 9.7925723e-01, 3.6990447e+00, 4.6588024e-01,
         4.5534578e-01, 2.7089839e+00, 2.9127163e+01, 9.0455323e-01,
         2.8571490e+02],
    to_rgb=False)
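
Statistics like these can be computed directly from the training npy files. A sketch, assuming the files are (H, W, 17) arrays under a hypothetical data/my_dataset/img_dir/train directory:

import glob

import numpy as np

# stack every training pixel into an (N, 17) matrix (paths are an assumption)
files = glob.glob('data/my_dataset/img_dir/train/*.npy')
pixels = np.concatenate(
    [np.load(f).reshape(-1, 17) for f in files]).astype(np.float64)
pixels = np.nan_to_num(pixels)  # the loader above also replaces NaN with 0

mean = pixels.mean(axis=0)  # 17-element vector for img_norm_cfg['mean']
std = pixels.std(axis=0)    # 17-element vector for img_norm_cfg['std']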

4 Modify the model's input channel count

Add/modify the input channel count in the backbone section of the model config (the argument name depends on the backbone; for this DWNet it is in_chans):

model = dict(
    type='EncoderDecoder',  # segmentor type
    pretrained=None,        # no pretrained backbone
    backbone=dict(
        type='DWNet',       # backbone type
        in_chans=17,        # <- number of input channels
        embed_dim=96,
        depths=[2, 2, 6, 2],  # backbone depths
        window_size=7,

Applying advanced image augmentation

Random window sampling (training)
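
The pipeline below refers to two tuples defined earlier in the config. With the data described in the preface they would be (exact values are an assumption based on the stated sizes):

img_origin = (512, 512)  # size of the source tiles
img_scale = (128, 128)   # size of the randomly sampled training window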

train_pipeline = [
    dict(type='LoadMultiResolutionImageFromFile', file_client_args=dict(backend='mem')),
    dict(type='LoadMultiResolutionAnnotations', map_labels=map_labels, file_client_args=dict(backend='mem')),
    dict(type='Resize', img_scale=img_origin, ratio_range=(0.5, 2.0), keep_ratio=False),   # random rescale
    dict(type='RandomRotate', prob=1, degree=(-180, 180)),  # random rotation
    dict(type='RandomCrop', crop_size=img_scale, cat_max_ratio=1.),   # random crop (if image > crop_size); cat_max_ratio caps the fraction any single class may occupy
    dict(type='RandomFlip', prob=0.5),  # random flip
    dict(type='Normalize', **img_norm_cfg),  # per-band normalization
    dict(type='Pad', size=img_scale, pad_val=0, seg_pad_val=255),   # guard against undersized images
    dict(type='DefaultFormatBundle'),   # channels-first, ToTensor, DataContainer wrapping
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),  # select the data passed downstream
]
test_pipeline = [
    dict(type='LoadMultiResolutionImageFromFile', file_client_args=dict(backend='mem')),
    dict(
        type='MultiScaleFlipAug',
        img_scale=img_scale,
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=False),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size=img_scale, pad_val=0, seg_pad_val=255),   # guard against undersized images
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'], meta_keys=('filename', 'ori_filename', 'ori_shape',
                                                          'img_shape', 'pad_shape', 'flip', 'img_norm_cfg')),
        ])
]

The pipeline first randomly rescales and rotates the original image, then crops a fixed-size window from the result. In numbers: with img_origin=(512, 512) and ratio_range=(0.5, 2.0), Resize yields an image between 256x256 and 1024x1024; RandomCrop then cuts a 128x128 window, and Pad guarantees 128x128 even if the resized image were smaller. (The original post illustrates both steps with figures, omitted here.)

Remaining problem

During training mmseg sends both the image and the label through the pipeline, but during validation and testing only the image goes through the pipeline; the label does not. The full-size ground truth therefore ends up being compared against predictions made on transformed images. The test_pipeline above only applies Resize, which at least partially mitigates the problem.

Sliding-window validation (evaluation)

Because of the problem above, I implemented sliding-window evaluation in a custom evaluation hook, following "mmSegmentation custom evaluation hooks (eval_hooks)".
The test_pipeline no longer rescales the data, and supplies the meta_keys the hook needs.

test_pipeline = [
    dict(type='LoadMultiResolutionImageFromFile', file_client_args=dict(backend='mem')),
    dict(
        type='MultiScaleFlipAug',
        img_scale=img_origin,
        flip=False,
        transforms=[
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size=img_origin, pad_val=0, seg_pad_val=255),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'], meta_keys=('filename', 'ori_filename', 'ori_shape',
                                                          'img_shape', 'pad_shape', 'flip', 'img_norm_cfg')),
        ])
]
evaluation = dict(interval=100, pre_eval=True, save_best='acc')  # evaluate every 100 iterations and keep the best checkpoint
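
The default EvalHook registered by train_segmentor knows nothing about sliding windows, so the custom hook below has to replace it. One way, sketched under the assumption that mmseg 0.23 imports its hooks in mmseg/apis/train.py (the module name my_eval_hooks is mine):

# mmseg/apis/train.py (sketch): substitute the custom single-GPU hook
# original line:
#     from mmseg.core import DistEvalHook, EvalHook
from mmseg.core import DistEvalHook
from my_eval_hooks import EvalHook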

my_eval_hooks.py

import os.path as osp
import warnings

import numpy as np
import torch
import torch.distributed as dist
from mmcv import print_log
from mmcv.runner import DistEvalHook as _DistEvalHook
from mmcv.runner import EvalHook as _EvalHook
from torch.nn.modules.batchnorm import _BatchNorm


class EvalHook(_EvalHook):
    """Single GPU EvalHook, with efficient test support.

    Args:
        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
            If set to True, it will perform by epoch. Otherwise, by iteration.
            Default: False.
        efficient_test (bool): Whether save the results as local numpy files to
            save CPU memory during evaluation. Default: False.
        pre_eval (bool): Whether to use progressive mode to evaluate model.
            Default: False.
    Returns:
        list: The prediction results.
    """

    greater_keys = ['mIoU', 'mAcc', 'aAcc', 'acc']  # 'acc' added so save_best='acc' can infer its comparison rule

    def __init__(self,
                 *args,
                 by_epoch=False,
                 efficient_test=False,
                 pre_eval=False,
                 **kwargs):
        super().__init__(*args, by_epoch=by_epoch, **kwargs)
        self.pre_eval = pre_eval
        if efficient_test:
            warnings.warn(
                'DeprecationWarning: ``efficient_test`` for evaluation hook '
                'is deprecated, the evaluation hook is CPU memory friendly '
                'with ``pre_eval=True`` as argument for ``single_gpu_test()`` '
                'function')

    def _do_evaluate(self, runner):
        """perform evaluation and save ckpt."""
        if not self._should_evaluate(runner):
            return

        runner.model.eval()

        results = []
        for batch_indices, data in zip(self.dataloader.batch_sampler, self.dataloader):
            # sliding-window inference over one full image
            label = self.dataloader.dataset.get_gt_seg_map_by_idx(batch_indices[0])
            with torch.no_grad():
                # window size = spatial shape of a training sample (C, H, W -> H, W)
                img_scale = tuple(runner.data_loader._dataloader.dataset[0]['img'].data.shape[1:])
                result = np.zeros_like(label)
                for i, l, c in self.sliding_window(data, label, img_scale):
                    result[c[0]:c[1], c[2]:c[3]] = runner.model(return_loss=False, **i)[0]
            C = self.dataloader.dataset.confusion_matrix(result, label)
            results.append(C)
        C = np.sum(results, 0)

        # derive metrics from the summed confusion matrix
        oa = self.dataloader.dataset.c2oa(C)
        pac = self.dataloader.dataset.c2pac(C)
        uac = self.dataloader.dataset.c2uac(C)
        kappa = self.dataloader.dataset.c2kappa(C)
        iou = self.dataloader.dataset.c2iou(C)
        miou = np.mean(iou)

        s = '\n'
        s += '-'*40 + '\n'
        s += 'Evaluation:\n'
        s += '-'*40 + '\n'
        s += 'OA   :\t{:.4f}\n'.format(oa)
        s += 'Kappa:\t{:.4f}\n'.format(kappa)
        s += 'mIoU :\t{:.4f}\n'.format(miou)
        s += '-'*40 + '\n'
        s += 'Producer_acc & User_acc & IoU:\n'
        for i, class_name in enumerate(self.dataloader.dataset.CLASSES):
            s += '{:^20}|{:^6.4f}|{:^6.4f}|{:^6.4f}\n'.format(class_name, pac[i], uac[i], iou[i])
        s += '-'*40 + '\n'
        print_log(s, logger=runner.logger)

        runner.log_buffer.clear()
        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)

        # save the best checkpoint
        if self.save_best:
            self._save_ckpt(runner, oa)

    def sliding_window(self, data, label, img_scale):
        """Yield (data, label, coords) for every window of size img_scale."""
        ori_shape = data['img_metas'][0].data[0][0]['img_shape']
        channel_num = ori_shape[-1]
        # patch the meta shapes so the model believes it sees window-sized inputs
        meta_shape = (*img_scale, channel_num)
        data['img_metas'][0].data[0][0]['img_shape'] = meta_shape
        data['img_metas'][0].data[0][0]['pad_shape'] = meta_shape
        data['img_metas'][0].data[0][0]['ori_shape'] = meta_shape
        img = data['img'][0].clone()
        lab = label.copy()
        # stride over the image in window-sized steps, clamping the last
        # window in each direction to the image border
        for y in range(0, ori_shape[0], img_scale[0]):
            if y + img_scale[0] > ori_shape[0]:
                y = ori_shape[0] - img_scale[0]
            for x in range(0, ori_shape[1], img_scale[1]):
                if x + img_scale[1] > ori_shape[1]:
                    x = ori_shape[1] - img_scale[1]
                c = [y, y + img_scale[0], x, x + img_scale[1]]
                data['img'][0] = img[:, :, c[0]:c[1], c[2]:c[3]]
                label = lab[c[0]:c[1], c[2]:c[3]]
                yield data, label, c
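
The hook relies on several confusion-matrix helpers on the dataset (confusion_matrix, c2oa, c2pac, c2uac, c2kappa, c2iou) that are not part of stock mmseg. A minimal sketch of what they compute, assuming C is a num_classes x num_classes matrix with ground truth along the rows, predictions along the columns, and 255 as the ignore index:

import numpy as np

class ConfusionMatrixMixin:
    """Hypothetical mixin for a CustomDataset subclass; names match the
    calls made by EvalHook above."""

    def confusion_matrix(self, pred, label):
        n = len(self.CLASSES)
        mask = label != 255  # drop ignored pixels
        idx = n * label[mask].astype(int) + pred[mask].astype(int)
        return np.bincount(idx, minlength=n * n).reshape(n, n)

    def c2oa(self, C):  # overall accuracy
        return np.diag(C).sum() / C.sum()

    def c2pac(self, C):  # producer's accuracy (per-class recall)
        return np.diag(C) / C.sum(axis=1)

    def c2uac(self, C):  # user's accuracy (per-class precision)
        return np.diag(C) / C.sum(axis=0)

    def c2kappa(self, C):  # Cohen's kappa
        po = np.diag(C).sum() / C.sum()
        pe = (C.sum(axis=0) * C.sum(axis=1)).sum() / C.sum() ** 2
        return (po - pe) / (1 - pe)

    def c2iou(self, C):  # per-class IoU
        return np.diag(C) / (C.sum(axis=0) + C.sum(axis=1) - np.diag(C))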

Gripes

This framework wore me out. It is very unfriendly to extending the data input pipeline: some of its decoupling makes it hard to know where to even start writing an extension, and some implementation choices hurt readability badly, so the learning curve is steep. That may just be because the framework is still young; the version number is only 0.23.0, after all. I hope it matures quickly.
