【论文阅读】SegNet算法详解

本文详细介绍了SegNet深度全卷积神经网络的结构和工作原理,该模型由编码器、解码器和像素级分类层组成。编码器使用卷积、批量归一化和ReLU激活函数,而解码器利用编码器的池化索引进行非线性上采样,减少了参数和计算量。Pytorch实现中,定义了Encoder和Decoder类,分别对应编码和解码过程,最后通过像素级分类层输出每个像素的类别概率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

SegNet论文详解

SegNet算法Pytorch实现:https://github.com/codecat0/CV/tree/main/Semantic_Segmentation/SegNet

本文提出了一种用于语义分割的深度全卷积神经网络结构SegNet,其核心由一个编码器网络和一个对应的解码器网络以及一个像素级分类层组成

本文的创新在于:
解码器使用在对应编码器的最大池化步骤中计算的池化索引来执行非线性上采样,这与反卷积相比,减少了参数量和运算量,而且消除了学习上采样的需要。
在这里插入图片描述

1. 网络结构

在这里插入图片描述

1.1 编码器

  1. Conv层
    • 通过卷积提取特征,其中使用的是same padding的卷积,不会改变特征图的尺寸
  2. BN层
    • 起到归一化的作用
  3. ReLU层
    • 起到激活函数的作用
  4. Pooling层
    • max pooling层,同时会记录最大值的索引位置

1.2 解码器

  1. Upsampling层
    在这里插入图片描述

    • 对输入的特征图放大两倍,然后把输入特征图的数据根据编码器pooling层的索引位置放入,其他位置为0
  2. Conv层

    • 通过卷积提取特征,其中使用的是same padding的卷积,不会改变特征图的尺寸
  3. BN层

    • 起到归一化的作用
  4. ReLU层

    • 起到激活函数的作用

1.3 像素级分类层

输出每一个像素点在所有类别概率,其中最大的概率类别为该像素的预测值

2. Pytorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(self, in_channels):
        super(Encoder, self).__init__()

        batchNorm_momentum = 0.1

        self.encode1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
        )

        self.encode2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
        )

        self.encode3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
        )

        self.encode4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
            
			nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
        )

        self.encode5 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        idx = []

        x = self.encode1(x)
        x, id1 = F.max_pool2d_with_indices(x, kernel_size=2, stride=2, return_indices=True)
        idx.append(id1)

        x = self.encode2(x)
        x, id2 = F.max_pool2d_with_indices(x, kernel_size=2, stride=2, return_indices=True)
        idx.append(id2)

        x = self.encode3(x)
        x, id3 = F.max_pool2d_with_indices(x, kernel_size=2, stride=2, return_indices=True)
        idx.append(id3)

        x = self.encode4(x)
        x, id4 = F.max_pool2d_with_indices(x, kernel_size=2, stride=2, return_indices=True)
        idx.append(id4)

        x = self.encode5(x)
        x, id5 = F.max_pool2d_with_indices(x, kernel_size=2, stride=2, return_indices=True)
        idx.append(id5)

        return x, idx


class Decoder(nn.Module):
    def __init__(self, out_channels):
        super(Decoder, self).__init__()

        batchNorm_momentum = 0.1

        self.decode1 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
        )

        self.decode2 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True)
        )

        self.decode3 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True)
        )

        self.decode4 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True)
        )

        self.decode5 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64, momentum=batchNorm_momentum),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x, idx):
        x = F.max_unpool2d(x, idx[4], kernel_size=2, stride=2)

        x = self.decode1(x)
        x = F.max_unpool2d(x, idx[3], kernel_size=2, stride=2)

        x = self.decode2(x)
        x = F.max_unpool2d(x, idx[2], kernel_size=2, stride=2)

        x = self.decode3(x)
        x = F.max_unpool2d(x, idx[1], kernel_size=2, stride=2)

        x = self.decode4(x)
        x = F.max_unpool2d(x, idx[0], kernel_size=2, stride=2)

        x = self.decode5(x)

        return x


class SegNet(nn.Module):
    # https://arxiv.org/abs/1511.00561
    def __init__(self, num_classes):
        super(SegNet, self).__init__()

        self.encode = Encoder(in_channels=3)
        self.decode = Decoder(out_channels=num_classes)

    def forward(self, x):
        x, idx = self.encode(x)
        x = self.decode(x, idx)
        return x


if __name__ == '__main__':
    input = torch.randn(1, 3, 384, 544)
    model = SegNet(num_classes=2)
    output = model(input)
    print(output.shape)
03-10
### SegNet 架构概述 SegNet 是一种专为图像分割设计的深度卷积神经网络架构,其独特之处在于利用最大池化索引进行上采样,从而有效保留并恢复输入图像的空间信息[^1]。 #### 编码器-解码器结构 SegNet 的整体框架由两部分组成:编码器(Encoder)和解码器(Decoder)。编码器负责提取特征图中的高层次抽象表示;而解码器则通过逐步放大这些特征图来重建原始分辨率下的像素级分类结果。这种对称式的编解码机制有助于保持空间位置的一致性和准确性[^3]。 #### 特征映射与反向传播 在网络训练过程中,编码阶段会记录下每次执行最大池化的精确位置信息作为索引表,在后续解码环节中依据此索引来指导相应的逆运算——即非线性插值过程。这种方法不仅减少了参数量还提高了计算效率,同时也增强了模型对于细粒度物体边界的捕捉能力。 ```python import torch.nn as nn class SegNet(nn.Module): def __init__(self, input_channels=3, output_channels=32): super(SegNet, self).__init__() # Encoder layers with max pooling indices recording self.encoder = ... # Decoder layers using recorded indices for upsampling self.decoder = ... def forward(self, x): pool_indices = [] # Forward pass through encoder saving indices during MaxPool operations encoded_features, pool_indices = self.encoder(x) # Use saved indices within decoder path for precise unpooling steps decoded_output = self.decoder(encoded_features, pool_indices) return decoded_output ``` #### 应用领域 得益于上述特性,SegNet 已经被广泛应用于多个计算机视觉任务当中,尤其是在城市街景解析、医学影像分析等领域表现尤为突出。例如,在自动驾驶辅助系统里可以用来识别道路标志物以及行人车辆等目标对象;而在医疗成像方面可用于肿瘤检测或者器官轮廓勾勒等工作[^2]。
评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值