【语义分割网络系列】三、SegNet


参考资料

论文

  SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation

博客

  [论文笔记] SegNet: Encoder-Decoder Architecture


第1章 前言

 作者先谈了之前的一些工作,认为虽然有很多方法,但是都很粗糙,其主要原因在于max pooling and sub-sampling reduce feature map resolution。而作者设计SegNet就是为了解决在语义分割中将低分辨率映射到原输入的分辨率上的问题。

 解码器是使用编码器中的maxpool的像素索引来进行反池化,从而免去了学习上采样的需要,这个idea是来自于无监督特征学习。在解码器中重新使用编码器池化时的索引下标有这么几个优点:

  • 能改善边缘的情况;
  • 减少了模型的参数;
  • 这种能容易就能整合到任何的编码器-解码器结构中,只需要稍稍改动。

 SegNet的灵感来源于场景理解应用。因此,它能有效地在推理期间减少内存占用和增加计算效率。与其他model相比,它参数数量也要少得多,并且可以使用随机梯度下降进行端到端的训练。

 最主要的贡献如下:

(1)明确提出了编码器-解码器架构

(2)提出了maxpool索引来解码的方法,节省了内存

在这里插入图片描述


第2章 SegNet网络结构

2.1 Encoder

  • 在编码器处,执行卷积和最大池化。
  • VGG-16有13个卷积层。(不用全连接的层)
  • 在进行2×2最大池化时,存储相应的最大池化索引(位置)。

 SegNet研究团队认为编码器下采样过程中图像信息损失较多,直接存储所有卷积块的特征图又非常占用内存,因而在SegNet中提出在每一次最大池化下采样前存储最大池化的位置索引( Max-pooling indices ),即记住最大池化操作中,最大值在 2 × 2 2\times2 2×2 池化窗口中的位置。每个 2 × 2 2\times2 2×2 窗口仅需要 2 bits内存存储量,这种池化位置索引可用于上采样解码时恢复图像信息。

在这里插入图片描述


2.2 Decoder

  • 在解码器处,执行上采样和卷积。最后,每个像素送到softmax分类器。
  • 在上采样期间,如上所示,调用相应编码器层处的最大池化索引以进行上采样,FCN中是利用双线性插值,再直接叠加上一层的特征图。
  • 最后,使用K类softmax分类器来预测每个像素的类别。

在这里插入图片描述


第3章 结论

 SegNet背后的主要动机是需要为道路和室内场景理解设计一个高效的网络,同时保证在内存和计算时间方面都是高效的。在实验中并将其与其他重要的变体网络进行了比较,以揭示在设计分割网络时所涉及的权衡,特别是训练时间、内存占用与准确性的权衡。完全存储编码器网络特征映射的架构性能最佳,但在推理期间消耗更多内存。另一方面,SegNet的效率更高,因为它只存储特征映射的maxpool索引,并在其解码器网络中使用它们来实现良好的性能。


第4章 Pytorch实现SegNet

参考

  GITHUB上有一个简单明了的PyTorch实现


import torch
import torch.nn as nn
import torch.nn.functional as F


class SegNet(nn.Module):
    def __init__(self, input_nbr, label_nbr):
        super(SegNet, self).__init__()

        batchNorm_momentum = 0.1

        self.conv11 = nn.Conv2d(input_nbr, 64, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm2d(64, momentum=batchNorm_momentum)
        self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm2d(64, momentum=batchNorm_momentum)

        self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn21 = nn.BatchNorm2d(128, momentum=batchNorm_momentum)
        self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn22 = nn.BatchNorm2d(128, momentum=batchNorm_momentum)

        self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn31 = nn.BatchNorm2d(256, momentum=batchNorm_momentum)
        self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn32 = nn.BatchNorm2d(256, momentum=batchNorm_momentum)
        self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn33 = nn.BatchNorm2d(256, momentum=batchNorm_momentum)

        self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn41 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
        self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn42 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
        self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn43 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)

        self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn51 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
        self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn52 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
        self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn53 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)

        self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn53d = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
        self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn52d = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
        self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn51d = nn.BatchNorm2d(512, momentum=batchNorm_momentum)

        self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn43d = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
        self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn42d = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
        self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.bn41d = nn.BatchNorm2d(256, momentum=batchNorm_momentum)

        self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn33d = nn.BatchNorm2d(256, momentum=batchNorm_momentum)
        self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn32d = nn.BatchNorm2d(256, momentum=batchNorm_momentum)
        self.conv31d = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.bn31d = nn.BatchNorm2d(128, momentum=batchNorm_momentum)

        self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn22d = nn.BatchNorm2d(128, momentum=batchNorm_momentum)
        self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn21d = nn.BatchNorm2d(64, momentum=batchNorm_momentum)

        self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn12d = nn.BatchNorm2d(64, momentum=batchNorm_momentum)
        self.conv11d = nn.Conv2d(64, label_nbr, kernel_size=3, padding=1)

    def forward(self, x):
        # Stage 1
        x11 = F.relu(self.bn11(self.conv11(x)))
        x12 = F.relu(self.bn12(self.conv12(x11)))
        x1p, id1 = F.max_pool2d(x12, kernel_size=2, stride=2, return_indices=True)

        # Stage 2
        x21 = F.relu(self.bn21(self.conv21(x1p)))
        x22 = F.relu(self.bn22(self.conv22(x21)))
        x2p, id2 = F.max_pool2d(x22, kernel_size=2, stride=2, return_indices=True)

        # Stage 3
        x31 = F.relu(self.bn31(self.conv31(x2p)))
        x32 = F.relu(self.bn32(self.conv32(x31)))
        x33 = F.relu(self.bn33(self.conv33(x32)))
        x3p, id3 = F.max_pool2d(x33, kernel_size=2, stride=2, return_indices=True)

        # Stage 4
        x41 = F.relu(self.bn41(self.conv41(x3p)))
        x42 = F.relu(self.bn42(self.conv42(x41)))
        x43 = F.relu(self.bn43(self.conv43(x42)))
        x4p, id4 = F.max_pool2d(x43, kernel_size=2, stride=2, return_indices=True)

        # Stage 5
        x51 = F.relu(self.bn51(self.conv51(x4p)))
        x52 = F.relu(self.bn52(self.conv52(x51)))
        x53 = F.relu(self.bn53(self.conv53(x52)))
        x5p, id5 = F.max_pool2d(x53, kernel_size=2, stride=2, return_indices=True)

        # Stage 5d
        x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2)
        x53d = F.relu(self.bn53d(self.conv53d(x5d)))
        x52d = F.relu(self.bn52d(self.conv52d(x53d)))
        x51d = F.relu(self.bn51d(self.conv51d(x52d)))

        # Stage 4d
        x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2)
        x43d = F.relu(self.bn43d(self.conv43d(x4d)))
        x42d = F.relu(self.bn42d(self.conv42d(x43d)))
        x41d = F.relu(self.bn41d(self.conv41d(x42d)))

        # Stage 3d
        x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2)
        x33d = F.relu(self.bn33d(self.conv33d(x3d)))
        x32d = F.relu(self.bn32d(self.conv32d(x33d)))
        x31d = F.relu(self.bn31d(self.conv31d(x32d)))

        # Stage 2d
        x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2)
        x22d = F.relu(self.bn22d(self.conv22d(x2d)))
        x21d = F.relu(self.bn21d(self.conv21d(x22d)))

        # Stage 1d
        x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2)
        x12d = F.relu(self.bn12d(self.conv12d(x1d)))
        x11d = self.conv11d(x12d)

        return x11d

    def load_from_segnet(self, model_path):
        s_dict = self.state_dict()  # create a copy of the state dict
        th = torch.load(model_path).state_dict()  # load the weigths
        # for name in th:
        # s_dict[corresp_name[name]] = th[name]
        self.load_state_dict(th)
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
### 回答1: 要使用SegNet PyTorch版本来训练自己的数据集,需要按照以下步骤进行操作。 首先,将自己的数据集准备好。数据集应包含带有相应标签的图像。确保所有图像的分辨率一致,并且标签图像与输入图像大小相匹配。 接下来,下载SegNet PyTorch版本的源代码,并配置所需的环境。PyTorch的安装是必需的,你可以根据自己的系统进行安装。此外,还需要安装其他可能需要的依赖项。 然后,将准备好的数据集分为训练集和测试集。确保训练集与测试集的标签图像都包含在对应的文件夹中,并且文件名与其对应的输入图像相同。 接下来,修改SegNet源代码以适应自己的数据集。在训练和测试过程中,需要根据数据集的类别数量修改网络的输出通道数,并根据输入图像的大小调整网络的输入尺寸。 在修改好源代码后,进行训练。使用训练集数据来训练网络,并调整超参数以达到更好的性能。可以通过调节批次大小、学习率和迭代次数等来调整训练速度和准确性。 训练完成后,可以使用测试集数据来评估网络的性能。查看网络在测试集上每个类别的预测结果,并计算准确性、精确度和召回率等评价指标。 最后,可以使用训练好的SegNet模型来对未知图像进行预测。加载模型并对待预测图像进行处理,最后得到图像的分割结果。 以上就是使用SegNet PyTorch版本训练自己的数据集的基本步骤。通过适应自己的数据集和调整超参数,可以获得更好的语义分割模型。 ### 回答2: SegNet是一种用于图像语义分割深度学习模型,其可以用于将输入图像分为不同的语义类别。如果要在PyTorch中使用SegNet模型,需要先准备自己的数据集并对其进行相应的处理。 首先,数据集需要包括输入图像和对应的标签图像。输入图像作为模型的输入,标签图像包含每个像素的语义类别信息。可以使用图像标注工具如labelImg对图像进行手动标注,或者使用已有的语义标注数据集。 接下来,需要将数据集分为训练集和验证集。可以按照一定的比例将数据集划分为两部分,其中一部分用于模型的训练,另一部分用于验证模型的性能。 然后,需要对数据集进行预处理。预处理的步骤包括图像的缩放、归一化和图像增强等。在PyTorch中,使用torchvision.transforms中的函数可以方便地进行这些处理。 接下来,需要定义数据加载器。可以使用PyTorch的DataLoader类读取预处理后的数据集,并将其提供给模型进行训练和验证。 在开始训练之前,需要加载SegNet模型。在PyTorch中,可以通过torchvision.models中的函数加载预定义的SegNet模型。可以选择预训练好的模型权重,或者将模型初始化为随机权重。 然后,需要定义损失函数和优化器。对于语义分割问题,常用的损失函数是交叉熵损失函数。可以使用torch.nn.CrossEntropyLoss定义损失函数。优化器可以选择Adam或SGD等常用的优化算法。 最后,开始模型的训练和验证。使用torch.nn.Module类创建SegNet模型的子类,并实现其forward函数。然后,通过迭代训练集的每个批次,使用损失函数计算损失,并使用优化器更新模型的参数。在每个epoch结束后,使用验证集评估模型的性能。 以上就是在PyTorch中使用SegNet模型进行图像语义分割的基本流程。通过按照上述步骤对自己的数据集进行处理,即可使用SegNet模型训练和验证自己的图像语义分割任务。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

travellerss

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值