SegNet Explained

PyTorch implementation of SegNet: https://github.com/codecat0/CV/tree/main/Semantic_Segmentation/SegNet

The paper proposes SegNet, a deep fully convolutional network architecture for semantic segmentation. Its core consists of an encoder network, a corresponding decoder network, and a pixel-wise classification layer.

The paper's main contribution:
The decoder performs non-linear upsampling using the pooling indices recorded during the max-pooling steps of the corresponding encoder. Compared with deconvolution (transposed convolution), this reduces both the parameter count and the computation, and it removes the need to learn the upsampling.
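A minimal sketch of the idea (assuming PyTorch; the tensor values are made up for illustration): max pooling records the position of each maximum, and max unpooling scatters values back to exactly those positions with no learnable weights, whereas a transposed convolution has parameters that must be trained.

import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

x = torch.tensor([[[[ 1.,  2.,  3.,  4.],
                    [ 5.,  6.,  7.,  8.],
                    [ 9., 10., 11., 12.],
                    [13., 14., 15., 16.]]]])

y, indices = pool(x)       # y = [[6, 8], [14, 16]]; indices record where each max came from
out = unpool(y, indices)   # back to 4x4: the maxima are restored, everything else is 0
print(out)

# Unpooling is parameter-free, unlike a learned transposed convolution:
print(sum(p.numel() for p in unpool.parameters()))   # 0
deconv = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2)
print(sum(p.numel() for p in deconv.parameters()))   # 5 (a 2x2 kernel + bias)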

1. Network Architecture


1.1 Encoder

  1. Conv layer
    • Extracts features by convolution; same-padding convolutions are used, so the spatial size of the feature map is unchanged
  2. BN layer
    • Applies batch normalization
  3. ReLU layer
    • Serves as the activation function
  4. Pooling layer
    • A max pooling layer that also records the index position of each maximum (see the sketch after this list)
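A minimal sketch of one encoder stage following these four steps (assuming PyTorch; the channel counts are illustrative): the same-padding 3x3 convolution preserves the spatial size, and the 2x2 max pooling halves it while returning the indices.

import torch
import torch.nn as nn
import torch.nn.functional as F

block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),  # same padding
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
)

x = torch.randn(1, 3, 224, 224)
features = block(x)                 # (1, 64, 224, 224): spatial size preserved
pooled, indices = F.max_pool2d(features, kernel_size=2, stride=2, return_indices=True)
print(pooled.shape, indices.shape)  # both (1, 64, 112, 112)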

1.2 Decoder

  1. Upsampling layer
    • Doubles the spatial size of the input feature map: each input value is placed at the position recorded by the corresponding encoder pooling layer's indices, and every other position is filled with zeros (see the sketch after this list)
  2. Conv layer
    • Extracts features by convolution; same-padding convolutions are used, so the spatial size of the feature map is unchanged
  3. BN layer
    • Applies batch normalization
  4. ReLU layer
    • Serves as the activation function
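A minimal sketch of the upsampling step (assuming PyTorch; shapes are illustrative): F.max_unpool2d produces a feature map twice the size, with the pooled values at the recorded positions and zeros everywhere else, which the following same-padding convolution then densifies.

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 64, 112, 112)
pooled, indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)

upsampled = F.max_unpool2d(pooled, indices, kernel_size=2, stride=2)
print(upsampled.shape)                  # (1, 64, 112, 112): twice the pooled 56x56
print((upsampled == 0).float().mean())  # ~0.75: three quarters of the entries are zero

conv = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)
dense = conv(upsampled)                 # the sparse map is densified by convolution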

1.3 Pixel-wise Classification Layer

For each pixel, the network outputs scores over all classes; the class with the highest probability is taken as that pixel's prediction (sketched below).
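A minimal sketch of this step (assuming the network outputs raw logits of shape (N, num_classes, H, W), as the implementation in Section 2 does):

import torch

logits = torch.randn(1, 21, 384, 544)  # (N, num_classes, H, W); 21 classes as an example
probs = logits.softmax(dim=1)          # per-pixel class probabilities
pred = probs.argmax(dim=1)             # (N, H, W): most probable class per pixel
print(pred.shape)                      # torch.Size([1, 384, 544])
# Since softmax is monotonic, logits.argmax(dim=1) gives the same prediction.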

2. PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """VGG-style encoder: five blocks of two Conv-BN-ReLU layers each;
    2x2 max pooling after every block records the indices of the maxima."""

    def __init__(self, in_channels):
        super(Encoder, self).__init__()

        batchNorm_momentum = 0.1

        self.encode1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
        )

        self.encode2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
        )

        self.encode3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
        )

        self.encode4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
        )

        self.encode5 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Collect the max-pooling indices of every stage so the decoder
        # can unpool to exactly the positions the maxima came from
        idx = []

        x = self.encode1(x)
        x, id1 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
        idx.append(id1)

        x = self.encode2(x)
        x, id2 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
        idx.append(id2)

        x = self.encode3(x)
        x, id3 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
        idx.append(id3)

        x = self.encode4(x)
        x, id4 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
        idx.append(id4)

        x = self.encode5(x)
        x, id5 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
        idx.append(id5)

        return x, idx


class Decoder(nn.Module):
    """Mirror of the encoder: each stage unpools with the stored indices,
    then densifies the sparse feature map with Conv-BN-ReLU layers."""

    def __init__(self, out_channels):
        super(Decoder, self).__init__()

        batchNorm_momentum = 0.1

        self.decode1 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True)
        )

        self.decode2 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True)
        )

        self.decode3 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True)
        )

        self.decode4 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True)
        )

        self.decode5 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)  # per-pixel class logits (no BN/ReLU)
        )

    def forward(self, x, idx):
        # Unpool with the stored indices in reverse order of the encoder
        # stages, then densify each sparse map with convolutions
        x = F.max_unpool2d(x, idx[4], kernel_size=2, stride=2)

        x = self.decode1(x)
        x = F.max_unpool2d(x, idx[3], kernel_size=2, stride=2)

        x = self.decode2(x)
        x = F.max_unpool2d(x, idx[2], kernel_size=2, stride=2)

        x = self.decode3(x)
        x = F.max_unpool2d(x, idx[1], kernel_size=2, stride=2)

        x = self.decode4(x)
        x = F.max_unpool2d(x, idx[0], kernel_size=2, stride=2)

        x = self.decode5(x)

        return x


class SegNet(nn.Module):
    # https://arxiv.org/abs/1511.00561
    def __init__(self, num_classes):
        super(SegNet, self).__init__()

        self.encode = Encoder(in_channels=3)
        self.decode = Decoder(out_channels=num_classes)

    def forward(self, x):
        x, idx = self.encode(x)
        x = self.decode(x, idx)
        return x


if __name__ == '__main__':
    dummy_input = torch.randn(1, 3, 384, 544)  # H and W must be divisible by 32 (five 2x poolings)
    model = SegNet(num_classes=2)
    output = model(dummy_input)
    print(output.shape)  # torch.Size([1, 2, 384, 544]): input resolution is fully restored
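For training, the raw logits can go straight into a cross-entropy loss. A minimal sketch, assuming the SegNet class defined above and a (N, H, W) tensor of per-pixel class indices as the target:

import torch
import torch.nn as nn

model = SegNet(num_classes=2)
criterion = nn.CrossEntropyLoss()  # accepts (N, C, H, W) logits and (N, H, W) targets

images = torch.randn(4, 3, 384, 544)
target = torch.randint(0, 2, (4, 384, 544))  # per-pixel labels in [0, num_classes)

logits = model(images)            # (4, 2, 384, 544)
loss = criterion(logits, target)
loss.backward()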