Multispectral Image Segmentation Algorithm

ECPN (segmenting grain edges in thin-section petrographic images with an extinction consistency perception network)

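1. Input: several images are concatenated along the channel dimension into a tensor of shape (Batch_size, C, H, W). With 7 RGB images as an example, the input shape is (1, 21, 1024, 1024).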

inputs = torch.rand(1, 21, 1024, 1024)

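2. The channel-concatenated tensor is then split back by color channel: a for loop over the 3 channels pulls the R, G, and B planes of the 7 frames out separately.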

for i in range(3):
    # x is the (B, 21, H, W) input; channel 3*k + i is color channel i of frame k
    input = (x[:, 0+i, :, :], x[:, 3+i, :, :], x[:, 6+i, :, :], x[:, 9+i, :, :], x[:, 12+i, :, :], x[:, 15+i, :, :], x[:, 18+i, :, :])

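3. Use torch.stack(input, dim=1) to stack the planes along dimension 1 (the channel dimension), once each for R, G, and B, giving a (B, 7, H, W) tensor per color channel.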

input = torch.stack(input, dim=1)
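
The indexing in steps 2 and 3 can be checked without PyTorch. The sketch below is plain Python and only lists which of the 21 input channels end up in each per-color stack. The helper name channel_groups is hypothetical, and it assumes the frames were concatenated in R, G, B order.

```python
def channel_groups(num_frames, channels_per_frame=3):
    """For each color channel, list its channel indices across all frames."""
    return [
        [frame * channels_per_frame + c for frame in range(num_frames)]
        for c in range(channels_per_frame)
    ]


groups = channel_groups(7)
for name, indices in zip("RGB", groups):
    # Each list holds the 7 channels that torch.stack(..., dim=1) would combine.
    print(name, indices)
# R [0, 3, 6, 9, 12, 15, 18]
# G [1, 4, 7, 10, 13, 16, 19]
# B [2, 5, 8, 11, 14, 17, 20]
```

Every one of the 21 channels appears in exactly one group, so no information is dropped by the regrouping.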

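4. Pass each stacked tensor (B, 7, H, W) through SingleConvBlock(7, 28, stride=1, kernel=3, use_bs=True), which consists of a convolution followed by batch normalization; the output is (B, 28, H, W).

# Purpose: convolve the feature map and, when use_bs is True, apply batch normalization to improve training quality and stability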
class SingleConvBlock(nn.Module):
    def __init__(self, in_features, out_features, stride, kernel=1, use_bs=True, padding=1):
        super(SingleConvBlock, self).__init__()
        self.use_bn = use_bs
        self.conv = nn.Conv2d(in_features, out_features, kernel, stride=stride, bias=False, padding=padding)
        self.bn = nn.BatchNorm2d(out_features)

    def forward(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        return x
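
Because SingleConvBlock defaults to padding=1, whether the spatial size is preserved depends on the kernel size passed in. The plain-Python sketch below applies the standard nn.Conv2d output-size formula (dilation 1); the helper name conv_out_size is hypothetical.

```python
def conv_out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a 2D convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


# kernel=3 with the default padding=1 keeps a 1024-pixel side unchanged.
print(conv_out_size(1024, kernel=3, stride=1, padding=1))  # 1024
# kernel=1 with the default padding=1 would grow the map by 2 pixels ...
print(conv_out_size(1024, kernel=1, stride=1, padding=1))  # 1026
# ... so a 1x1 block needs padding=0 to keep the size.
print(conv_out_size(1024, kernel=1, stride=1, padding=0))  # 1024
```

This is presumably why the 1x1 block in step 12 is created with padding=0.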

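5. Pass the result (B, 28, H, W) through an SE block (SELayer), which consists of nn.AdaptiveAvgPool2d(1) and an nn.Sequential of fully connected layers.

"""
Purpose: nn.AdaptiveAvgPool2d(1) is an adaptive average pooling layer; it compresses input feature maps of any spatial size to a 1x1 output, so each feature map becomes one entry of a global feature vector.
     nn.Sequential generates a channel attention weight vector that rescales the importance of each channel of the input feature map, so the network can focus on the more informative features.
     The SE block reweights the channels with these learned weights, letting the model adapt the channel weights to the input and thereby improving its representational power.
"""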
class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
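
The width of the SE bottleneck is channel // reduction, which is worth checking for a tensor with only 28 channels. The post does not show which reduction value is passed for this layer, so the sketch below simply compares the default of 16 with a smaller value; the helper name se_bottleneck is hypothetical.

```python
def se_bottleneck(channel, reduction=16):
    """Hidden width and weight count of the two bias-free Linear layers in SELayer."""
    hidden = channel // reduction
    params = channel * hidden + hidden * channel
    return hidden, params


print(se_bottleneck(28))               # (1, 56): default reduction leaves one hidden unit
print(se_bottleneck(28, reduction=4))  # (7, 392)
print(se_bottleneck(7))                # (0, 0): fewer channels than reduction gives an empty layer
```

With the default reduction the 28 channel weights are all derived from a single hidden unit, so a smaller reduction may be worth trying if the channel attention seems too coarse.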

6. Pass the result through SingleConvBlock(28, 1, stride=1, kernel=3, use_bs=True), the same convolution plus batch normalization block as in step 4; the output is (B, 1, H, W).

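7. Finally, concatenate the three single-channel outputs (one each for R, G, and B) along the channel dimension with torch.cat(input_list, dim=1), giving a three-channel image again (B, 3, H, W).

x = torch.cat(input_list, dim=1)  # Bx3xHxW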

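8. Use tf_efficientnetv2_xl.in21k_ft_in1k for feature extraction. This differs from the original authors' setup in that the backbone was swapped: both networks are based on EfficientNetV2, but this pretrained model was trained on ImageNet-21K and fine-tuned on ImageNet-1K; it has more parameters and gave somewhat better results.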

all_pretrained_models_available = timm.list_models(pretrained=True)
config = _cfg(url='', file=r'F:\weights\model.safetensors')
model = timm.create_model('tf_efficientnetv2_xl.in21k_ft_in1k', features_only=True,
                          pretrained=True, pretrained_cfg=config)
self.model = model

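9. The 5 multi-scale feature maps from the backbone go through Bidecoder for multi-scale feature fusion. Here WindowAttention() is applied to p5_up before it is fused with p4_in. The decoder again outputs 5 feature maps.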

class Bidecoder(nn.Module):
    def __init__(self, onnx_export=False):
        super(Bidecoder, self).__init__()
        self.conv6_up = SeparableConvBlock(256, onnx_export=onnx_export)
        self.conv5_up = SeparableConvBlock(96, onnx_export=onnx_export)
        self.conv4_up = SeparableConvBlock(64, onnx_export=onnx_export)
        self.conv3_up = SeparableConvBlock(32, onnx_export=onnx_export)
        self.conv4_down = SeparableConvBlock(64, onnx_export=onnx_export)
        self.conv5_down = SeparableConvBlock(96, onnx_export=onnx_export)
        self.conv6_down = SeparableConvBlock(256, onnx_export=onnx_export)
        self.conv7_down = SeparableConvBlock(640, onnx_export=onnx_export)
    
        self.p6_upsample = nn.ConvTranspose2d(640, 256, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.p5_upsample = nn.ConvTranspose2d(256, 96, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.p4_upsample = nn.ConvTranspose2d(96, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.p3_upsample = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
    
        self.p4_downsample = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.p5_downsample = nn.Conv2d(64, 96, kernel_size=3, stride=2, padding=1)
        self.p6_downsample = nn.Conv2d(96, 256, kernel_size=3, stride=2, padding=1)
        self.p7_downsample = nn.Conv2d(256, 640, kernel_size=3, stride=2, padding=1)
    
        self.swish = nn.SiLU() if not onnx_export else nn.Sigmoid()

    def forward(self, inputs):
        """
            P7_0 -------------------------> P7_2 -------->
               |-------------|                ↑
                             ↓                |
            P6_0 ---------> P6_1 ---------> P6_2 -------->
               |-------------|--------------↑ ↑
                             ↓                |
            P5_0 ---------> P5_1 ---------> P5_2 -------->
               |-------------|--------------↑ ↑
                             ↓                |
            P4_0 ---------> P4_1 ---------> P4_2 -------->
               |-------------|--------------↑ ↑
                             |--------------↓ |
            P3_0 -------------------------> P3_2 -------->
        """

        p3_in, p4_in, p5_in, p6_in, p7_in = inputs


        p6_up = self.conv6_up(self.swish(p6_in + self.p6_upsample(p7_in)))
        p5_up = self.conv5_up(self.swish(p5_in + self.p5_upsample(p6_up)))

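        # Window attention on p5_up: (B, C, H, W) -> (B, H, W, C)
        B, C, H, W = p5_up.shape
        x = p5_up.permute(0, 2, 3, 1)
        window_size = 8
        self.shift_size = 0
        self.window_size = window_size
        # NOTE: building WindowAttention inside forward() re-initializes its weights
        # on every call, so they are never trained; constructing it once in
        # __init__ would make the attention learnable.
        self.attn = WindowAttention(
            96, window_size=to_2tuple(self.window_size), num_heads=4,
            qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.)
        # cyclic shift (skipped here because shift_size is 0)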
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C

        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

        # reverse cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = shifted_x
        p5_up = x.permute(0, 3, 1, 2)
        p4_up = self.conv4_up(self.swish(p4_in + self.p4_upsample(p5_up)))
        p3_out = self.conv3_up(self.swish(p3_in + self.p3_upsample(p4_up)))
        p4_out = self.conv4_down(self.swish(p4_in + p4_up + self.p4_downsample(p3_out)))
        p5_out = self.conv5_down(self.swish(p5_in + p5_up + self.p5_downsample(p4_out)))
        p6_out = self.conv6_down(self.swish(p6_in + p6_up + self.p6_downsample(p5_out)))
        p7_out = self.conv7_down(self.swish(p7_in + self.p7_downsample(p6_out)))

        return [p3_out, p4_out, p5_out, p6_out, p7_out]
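
The shapes inside Bidecoder can be sanity-checked with plain arithmetic. The sketch below assumes the five backbone features sit at strides 2, 4, 8, 16, and 32 of a 1024x1024 input, which is consistent with the stride-2 up/downsampling layers above but is not stated in the post. Both helper names are hypothetical.

```python
def conv_transpose_out_size(size, kernel=3, stride=2, padding=1, output_padding=1):
    """Output size of nn.ConvTranspose2d with dilation 1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def window_grid(height, width, window_size):
    """Number of non-overlapping windows and tokens per window."""
    if height % window_size or width % window_size:
        raise ValueError("feature map must be divisible by the window size")
    return (height // window_size) * (width // window_size), window_size ** 2


input_size = 1024
strides = {"p3": 2, "p4": 4, "p5": 8, "p6": 16, "p7": 32}  # assumed feature strides
sizes = {name: input_size // s for name, s in strides.items()}
print(sizes)  # {'p3': 512, 'p4': 256, 'p5': 128, 'p6': 64, 'p7': 32}

# Each p*_upsample layer exactly doubles the side length, so p7 (32) matches p6 (64).
print(conv_transpose_out_size(sizes["p7"]))  # 64

# Window attention on p5 with window_size=8: 256 windows of 64 tokens per image.
print(window_grid(sizes["p5"], sizes["p5"], 8))  # (256, 64)
```

The divisibility check matters in practice: window partitioning with window_size=8 only works when the p5 side length is a multiple of 8, which under these assumptions means an input side that is a multiple of 64.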

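10. Each output is passed through UpConvBlock_subpixel, an upsampling convolution block of the kind commonly used in image super-resolution. Its key idea is upsampling by sub-pixel convolution (nn.PixelShuffle), combined with a convolution and a nonlinear activation.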

class UpConvBlock_subpixel(nn.Module):
    def __init__(self, in_features, up_scale, out_channels):
        super(UpConvBlock_subpixel, self).__init__()
        self.conv_for_subpixel = nn.Conv2d(in_features, out_channels=out_channels, kernel_size=1)
        self.subpixel_conv = nn.PixelShuffle(up_scale)
        self.swish = MemoryEfficientSwish()

    def forward(self, x):
        x = self.conv_for_subpixel(x)
        x = self.swish(x)
        x = self.subpixel_conv(x)
        return x
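
nn.PixelShuffle(r) rearranges a (B, C*r*r, H, W) tensor into (B, C, H*r, W*r), so the 1x1 convolution in front of it has to produce r*r channels for every output channel. The post does not list the up_scale and out_channels used per level; the sketch below assumes each level is brought back to one channel at 1024x1024 and that the levels sit at strides 2 to 32. The helper name pixel_shuffle_shape is hypothetical.

```python
def pixel_shuffle_shape(shape, up_scale):
    """Output shape of nn.PixelShuffle(up_scale) for a (B, C, H, W) input."""
    b, c, h, w = shape
    if c % (up_scale ** 2):
        raise ValueError("channels must be divisible by up_scale squared")
    return (b, c // up_scale ** 2, h * up_scale, w * up_scale)


for stride in (2, 4, 8, 16, 32):  # assumed feature strides of the five levels
    side = 1024 // stride
    out_channels = stride ** 2  # channels the 1x1 conv must emit for one output channel
    print(stride, out_channels, pixel_shuffle_shape((1, out_channels, side, side), stride))
```

Under these assumptions every level comes out as (1, 1, 1024, 1024), which matches the five channels concatenated in the next step.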

11. The upsampled features are concatenated along the channel dimension with torch.cat(results, dim=1), giving (B, 5, 1024, 1024).

block_cat = torch.cat(results, dim=1)  # Bx5xHxW

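12. Apply SingleConvBlock(5, 1, stride=1, use_bs=False, padding=0). With these arguments the block is a single 1x1 convolution (kernel defaults to 1), and batch normalization is skipped because use_bs=False; the output is (B, 1, H, W).

'''
SingleConvBlock is a simple building block made of a convolution layer and an optional batch normalization layer.
It convolves the input feature map and, when requested, applies batch normalization to improve training quality and stability.
'''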
class SingleConvBlock(nn.Module):
    def __init__(self, in_features, out_features, stride, kernel=1, use_bs=True, padding=1):
        super(SingleConvBlock, self).__init__()
        self.use_bn = use_bs
        self.conv = nn.Conv2d(in_features, out_features, kernel, stride=stride, bias=False, padding=padding)
        self.bn = nn.BatchNorm2d(out_features)

    def forward(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        return x

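13. Return the result. This ends the feature-learning pass; for training, compute the loss directly between this output and the annotated mask.

result = block_cat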
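
The post does not say which loss is used against the mask. As one common choice for binary edge maps, the sketch below computes a numerically stable binary cross-entropy on raw logits in plain Python. The choice of loss, the assumption that the network output is an unnormalized logit, and both function names are illustrative assumptions, not the author's training code.

```python
import math


def bce_with_logits(logit, target):
    """Stable binary cross-entropy for one pixel: target is 0.0 or 1.0."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def mean_bce(logits, targets):
    """Average the per-pixel loss over two equally sized flat lists."""
    if len(logits) != len(targets):
        raise ValueError("prediction and mask must have the same number of pixels")
    return sum(bce_with_logits(x, y) for x, y in zip(logits, targets)) / len(logits)


# A confident correct edge pixel costs little; a confident wrong one costs a lot.
print(bce_with_logits(4.0, 1.0))
print(bce_with_logits(4.0, 0.0))
print(mean_bce([4.0, -4.0, 0.0], [1.0, 0.0, 1.0]))
```

In PyTorch the same quantity is available as torch.nn.functional.binary_cross_entropy_with_logits; edge pixels are usually much rarer than background, so a class-weighted variant may be a better fit.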

Reference paper:
The edge segmentation of grains in thin-section petrographic images utilising extinction consistency perception network
