分割任务 - 特征端融合代码 - MMSegmentation 框架,以 K-Net 为例

1. 在 mmseg/datasets 下新建自己的数据集文件(如 my_dataset1.py),并记得在 mmseg/datasets/__init__.py 中导入该类(加入 __all__)。

# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .basesegdataset import BaseCDDataset


@DATASETS.register_module()
class Mydataset_1(BaseCDDataset):
    """Binary water-segmentation dataset with two input images per sample.

    Each sample is made of a first image (looked up under
    ``data_prefix['img_path']`` with ``img_suffix``), a second image
    (``data_prefix['img_path2']`` with ``img_suffix2``) and a label map
    (``data_prefix['seg_map_path']`` with ``seg_map_suffix``). How the three
    files are paired is handled by ``BaseCDDataset`` (presumably by shared
    file name — confirm against that class).

    Class index 0 is ``nowater`` and 1 is ``water``. Because 0 is a real
    class rather than an ignore index, ``reduce_zero_label`` defaults to
    False.

    Args:
        img_suffix (str): Suffix of the first image. Defaults to '.tif'.
        img_suffix2 (str): Suffix of the second image. Defaults to '.tif'.
        seg_map_suffix (str): Suffix of the label map. Defaults to '.png'.
        reduce_zero_label (bool): Whether to shift labels so that 0 becomes
            the ignore index. Defaults to False.
        **kwargs: Forwarded unchanged to ``BaseCDDataset``.
    """
    METAINFO = dict(
        classes=('nowater', 'water'),
        palette=[[0, 0, 0], [0, 255, 255]])

    def __init__(self,
                 img_suffix='.tif',
                 img_suffix2='.tif',
                 seg_map_suffix='.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            img_suffix2=img_suffix2,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)

2. 修改 configs/_base_/datasets 下的数据集配置(本例为 flood3.py):把数据路径改成自己的数据路径。这里输入的两组数据尺寸均为 128×128×3。

# dataset settings
# Every sample is two images stacked channel-wise by
# LoadMultipleRSImageFromFile plus one label map.
dataset_type = 'Mydataset_1'
data_root = '/root/data1/Track2'
crop_size = (512, 512)

train_pipeline = [
    dict(type='LoadMultipleRSImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=False),
    dict(
        type='RandomResize',
        scale=(512, 512),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    # NOTE(review): presumably disabled because it expects a 3-channel
    # image while the stacked input has 6 channels — confirm before enabling.
    # dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]

val_pipeline = [
    dict(type='LoadMultipleRSImageFromFile'),
    dict(type='Resize', scale=(512, 512), keep_ratio=True),
    # add loading annotation after ``Resize`` because ground truth
    # does not need to do resize data transform
    dict(type='LoadAnnotations', reduce_zero_label=False),
    dict(type='PackSegInputs')
]

test_pipeline = [
    dict(type='LoadMultipleRSImageFromFile'),
    dict(type='Resize', scale=(512, 512), keep_ratio=True),
    # add loading annotation after ``Resize`` because ground truth
    # does not need to do resize data transform
    dict(type='LoadAnnotations', reduce_zero_label=False),
    dict(type='PackSegInputs')
]

# Test-time-augmentation scale factors. The inputs are small tiles
# (128x128 in this write-up), so factors of 4-6 bring them to ~512-768 px.
# NOTE(review): 5.0 is listed twice, so that scale is run twice — likely a
# typo; confirm before relying on the TTA result.
# img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
img_ratios = [4.0, 4.5, 5.0, 5.0, 5.5, 6.0]

tta_pipeline = [
    dict(type='LoadMultipleRSImageFromFile'),
    # dict(type='LoadImageFromFile',backend_args=None ),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]

# The second modality (img_path2) must be the same for all three splits:
# a model trained on images+rgb cannot be evaluated on images+cdem. Switch
# train/val/test together when running the DEM experiment.
train_dataloader = dict(
    batch_size=4,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='/root/data1/Track2/train/images',
            # img_path2='/root/data1/Track2/train/train_cdem',
            img_path2='/root/data1/Track2/train/train_rgb',
            seg_map_path='/root/data1/Track2/train/labels'),
        pipeline=train_pipeline))

val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='/root/data1/Track2/val/images',
            # img_path2='/root/data1/Track2/val/val_cdem',
            img_path2='/root/data1/Track2/val/val_rgb',
            seg_map_path='/root/data1/Track2/val/labels'),
        pipeline=val_pipeline))

test_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='/root/data1/Track2/test/images',
            # img_path2='/root/data1/Track2/test/test_cdem',
            img_path2='/root/data1/Track2/test/test_rgb',
            seg_map_path='/root/data1/Track2/test/labels'),
        pipeline=test_pipeline))

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
# ``save_best`` is a CheckpointHook option, not an IoUMetric argument. To keep
# the best checkpoint, set in the top-level config:
#   default_hooks = dict(checkpoint=dict(save_best='mIoU'))
test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])

3. 改写 configs 中的 K-Net 配置:均值和方差需按自己的数据计算后填入,且必须与输入通道逐一对应(注意 cv2 读图为 BGR 顺序);这里训练 20k 次迭代。

_base_ = [
    '../_base_/datasets/flood3.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_20k.py'
]
crop_size = (512, 512)
# LoadMultipleRSImageFromFile decodes both images with cv2 (BGR order) and
# stacks them into 6 channels. ``bgr_to_rgb`` stays off; SegDataPreProcessor
# only swaps channels of 3-channel inputs anyway. The per-channel statistics
# must therefore be listed B, G, R for the first image, then B, G, R for the
# second. The values below are ImageNet placeholders in that order; replace
# them with statistics computed on your own data, one entry per channel.
data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[103.53, 116.28, 123.675, 103.53, 116.28, 123.675],
    std=[57.375, 57.12, 58.395, 57.375, 57.12, 58.395],
    # +swir2
    # mean=[46.69228311, 102.80634228, 100.59608681,87.7220806],
    # std=[15.8533346, 34.881314, 38.23240675,140.73692343],
    # +water
    # mean=[123.675, 116.28, 103.53,123.675],
    # std=[58.395, 57.12, 57.375,58.395],
    bgr_to_rgb=False,
    pad_val=0,
    size=crop_size,
    seg_pad_val=255)
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
num_stages = 3
conv_kernel_size = 1
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    # NOTE(review): the two-stream stem (stem_first_channels /
    # stem_second_channels) presumably has no matching keys in this
    # checkpoint, so the stem is trained from scratch — check the load log.
    pretrained='open-mmlab://resnest101',
    backbone=dict(
        type='ResNeSt',
        # in_channels=6,
        # in_first_channels + in_second_channels must equal the number of
        # mean/std entries above (the stacked input channel count).
        in_first_channels=3,
        in_second_channels=3,
        stem_channels=128,
        radix=2,
        reduction_factor=4,
        avg_down_stride=True,
        depth=101,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='IterativeDecodeHead',
        num_stages=num_stages,
        kernel_update_head=[
            dict(
                type='KernelUpdateHead',
                num_classes=2,
                num_ffn_fcs=2,
                num_heads=8,
                num_mask_fcs=1,
                feedforward_channels=2048,
                in_channels=512,
                out_channels=512,
                dropout=0.0,
                conv_kernel_size=conv_kernel_size,
                ffn_act_cfg=dict(type='ReLU', inplace=True),
                with_ffn=True,
                feat_transform_cfg=dict(
                    conv_cfg=dict(type='Conv2d'), act_cfg=None),
                kernel_updator_cfg=dict(
                    type='KernelUpdator',
                    in_channels=256,
                    feat_channels=256,
                    out_channels=256,
                    act_cfg=dict(type='ReLU', inplace=True),
                    norm_cfg=dict(type='LN'))) for _ in range(num_stages)
        ],
        kernel_generate_head=dict(
            type='FCNHead',
            in_channels=2048,
            in_index=3,
            channels=512,
            num_convs=2,
            concat_input=True,
            dropout_ratio=0.1,
            num_classes=2,
            norm_cfg=norm_cfg,
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

# optimizer
optim_wrapper = dict(
    _delete_=True,
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0005),
    clip_grad=dict(max_norm=1, norm_type=2))
# learning policy: 250-iteration linear warm-up, then step decay (x0.1 by
# MultiStepLR default) at 15k and 18k, ending at 20k iterations.
param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0,
        end=125 * 2),
    dict(
        type='MultiStepLR',
        begin=125 * 2,
        end=10000 * 2,
        milestones=[7500 * 2, 9000 * 2],
        by_epoch=False,
    )
]
# Batch size 4 per GPU here (the reference K-Net configs use 2 per GPU).
train_dataloader = dict(batch_size=4, num_workers=2)
val_dataloader = dict(batch_size=1, num_workers=2)
# Only overrides batch size / workers; the test dataset itself still comes
# from the base dataset config's test_dataloader.
test_dataloader = val_dataloader
work_dir = '/root/data1/work_dir/k-net8_fusion_rgb'

4.在mmseg/datasets/transforms/loading.py里的class LoadMultipleRSImageFromFile中修改

@TRANSFORMS.register_module()
class LoadMultipleRSImageFromFile(BaseTransform):
    """Load two co-registered images and stack them along the channel axis.

    Both files are read with ``fileio.get`` and decoded with
    ``mmcv.imfrombytes(flag='color', backend='cv2')``, so each becomes an
    ``(H, W, 3)`` array in cv2's BGR channel order. The two arrays are
    concatenated into one ``(H, W, 6)`` image, first image's channels first.
    This is the layout the two-stream backbone splits back apart.

    Required Keys:

    - img_path
    - img_path2

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the stacked image to a float32
            numpy array. If False, the image keeps the dtype produced by the
            decoder (uint8 for ordinary 8-bit files). Defaults to True.
    """

    def __init__(self, to_float32: bool = True):
        self.to_float32 = to_float32

    def transform(self, results: Dict) -> Dict:
        """Load both images, stack them, and store the result in ``results``.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``; must
                contain ``img_path`` and ``img_path2``.

        Returns:
            dict: ``results`` with ``img``, ``img_shape`` and ``ori_shape``
            set.

        Raises:
            ValueError: If the two images differ in height or width.
        """
        filename = results['img_path']
        filename2 = results['img_path2']

        img = mmcv.imfrombytes(
            fileio.get(filename), flag='color', backend='cv2')
        img2 = mmcv.imfrombytes(
            fileio.get(filename2), flag='color', backend='cv2')

        # Channel-wise stacking needs identical spatial size; report which
        # files disagree instead of surfacing a bare numpy error.
        if img.shape[:2] != img2.shape[:2]:
            raise ValueError(
                f'Image shapes do not match: {filename} is {img.shape}, '
                f'{filename2} is {img2.shape}')

        # Always stack both images. Previously the concatenation lived inside
        # the ``to_float32`` branch, so to_float32=False silently dropped the
        # second image and produced a 3-channel input.
        img = np.concatenate((img, img2), axis=2)
        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str

5. 修改 mmseg/models/backbones/resnet.py,这里以 ResNet 为例

class ResNet(BaseModule):
    def __init__(self,
                 depth,
                #  in_channels=3,
                 #特征端融合first和second输入两通道
                 in_first_channels=3,
                 in_second_channels=3,
                 stem_channels=64,
                 base_channels=64,
    .....
#这里增加了初始化输入的两个通道

#修改一下这部分
        # self._make_stem_layer(in_channels, stem_channels)
        self._make_stem_layer( in_first_channels, in_second_channels, stem_channels)

#这个函数复制两份
    def _make_stem_layer(self, in_first_channels, in_second_channels,stem_channels):
        """Build one stem per input stream for feature-level fusion.

        Creates ``stem_first_channels`` and ``stem_second_channels``, two
        independent stems. Each is three conv-norm-ReLU units with 3x3
        kernels; the first conv has stride 2. Each stem maps its stream to
        ``stem_channels`` feature maps. A ``maxpool`` layer (3x3, stride 2)
        is also created and shared by both streams.

        Args:
            in_first_channels (int): Channels of the first input stream.
            in_second_channels (int): Channels of the second input stream.
            stem_channels (int): Output channels of each stem.

        Raises:
            NotImplementedError: If ``self.deep_stem`` is False. Only the
                deep two-stream stem is implemented.
        """
        if not self.deep_stem:
            # The previous single-stem branch referenced an undefined
            # ``in_channels`` (NameError), and ``forward`` has no matching
            # path, so fail early with an explicit message.
            raise NotImplementedError(
                'Two-stream feature fusion stem is only implemented for '
                'deep_stem=True.')
        # Feature-level fusion: one independent stem per input stream.
        self.stem_first_channels = self._build_fusion_deep_stem(
            in_first_channels, stem_channels)
        self.stem_second_channels = self._build_fusion_deep_stem(
            in_second_channels, stem_channels)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _build_fusion_deep_stem(self, in_channels, stem_channels):
        """Return a 3x(conv3x3-norm-ReLU) stem: in_channels -> stem_channels."""
        half = stem_channels // 2
        return nn.Sequential(
            build_conv_layer(
                self.conv_cfg,
                in_channels,
                half,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False),
            build_norm_layer(self.norm_cfg, half)[1],
            nn.ReLU(inplace=True),
            build_conv_layer(
                self.conv_cfg,
                half,
                half,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False),
            build_norm_layer(self.norm_cfg, half)[1],
            nn.ReLU(inplace=True),
            build_conv_layer(
                self.conv_cfg,
                half,
                stem_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False),
            build_norm_layer(self.norm_cfg, stem_channels)[1],
            nn.ReLU(inplace=True))


#forward部分修改
    #特征端融合的代码
    def forward(self, x):
        """Encode two channel-stacked input streams and fuse them by summation.

        ``x`` is split along dim 1. The first
        ``stem_first_channels[0].in_channels`` channels go through
        ``stem_first_channels``. The next
        ``stem_second_channels[0].in_channels`` channels go through
        ``stem_second_channels`` (3 + 3 in the shipped config). Both streams
        are max-pooled and passed through the same residual stages. At every
        stage index in ``out_indices`` the two feature maps are added
        element-wise.

        Args:
            x (Tensor): Input of shape (N, C, H, W) with both streams stacked
                on the channel axis.

        Returns:
            tuple[Tensor]: Fused feature maps, one per entry of
            ``out_indices``.

        Raises:
            NotImplementedError: If ``self.deep_stem`` is False. Only the
                deep two-stream stem is implemented.
        """
        if not self.deep_stem:
            # The non-deep path previously used undefined x1/x2 (NameError).
            raise NotImplementedError(
                'Two-stream feature fusion is only implemented for '
                'deep_stem=True.')
        # Derive the channel split from the stems rather than hard-coding
        # 0:3 / 3:6, so in_first_channels / in_second_channels are honoured.
        n_first = self.stem_first_channels[0].in_channels
        n_second = self.stem_second_channels[0].in_channels
        x1 = self.stem_first_channels(x[:, :n_first])
        x2 = self.stem_second_channels(x[:, n_first:n_first + n_second])
        # Pool only the stem outputs; pooling the raw stacked input (as
        # before) produced a result that was never used.
        x1 = self.maxpool(x1)
        x2 = self.maxpool(x2)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            # NOTE: both streams go through the same module, so the residual
            # stages (weights and norm statistics) are shared between them.
            x1 = res_layer(x1)
            x2 = res_layer(x2)
            if i in self.out_indices:
                outs.append(x1 + x2)
        return tuple(outs)

其他部分基本不需要改动

  • 5
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值