图像分割—— Encoder-Decoder Based Models(SegNet)理解和代码分析

Encoder-Decoder Based Models简介:
Encoder-Decoder Based Models是非常经典的一种图像分割模型(FNC网络,严格来说,其实也是这种模式)。模型通常由两部分组成,即Encoder部分和Decoder 部分。Encoder会通过convolutional layers对image进行下采样,使得图片size减小,而channel增加。常见使用的下采样网络有:vgg-16、ResNet。通过去掉下采样网络的fn层,只保留卷积层,从而达到对图片下采样(Encoder)的目的。Decoder 有两种方法:使用deconv对图片上采样,从而将Encoder的特征图Decoder 为像素分割图;对特征图进行unpooling后,做same卷积。
更直观的理解就是,Encoder负责将一张图片的每个像素点,通过复杂的计算过程,映射到某一个高维分布上,而Decoder则是负责将这个高维分布,映射到给定的类别区域。中间的高维分布,是我们不可见的,但神经网络却可以很好的使用它。正是这种借助中间的高维分布的思想,搭建起来了原图像到像素级分类图像的桥梁,实现了end-to-end的训练过程。

SegNet的背景:
SegNet是一个由剑桥大学团队开发的图像分割的开源项目,该项目可以对图像中的物体所在区域进行分割,例如车,马路,行人等,并且精确到像素级别。但是由于传统的SegNet的编码器和解码器都依赖于深层的卷积神经网络,计算量较大,所以在实际使用中,我们为了达到实时分割的目的,通常使用其简化版(对通道数进行了削减)——SegNet Basic

想法和方法:
SegNet过程示意图
SegNet的编码器,采用的是去掉了全连接层的vgg-16网络。在Encoder阶段,每一层通过bn,relu和conv提取特征,通过maxpool进行降维(卷积阶段图片size不变,maxpool阶段图片size缩小一倍)。在进行maxpool的过程中,SegNet创新性的提出了储存maxpool的选择目标的原始位置,留作后面unpooling使用在Decoder 阶段,每一次通过利用前面记录的对应层的maxpool位置进行unpooling,然后对unpooling后的图片进行bn,relu和conv进行特征还原(unpooling阶段图片size增大一倍,卷积阶段图片size不变)。最终将得到和原图像一样大小的图片经过softmax激活,得到像素级分割图。

优点:
文中使用了记录原始位置的maxpool,这在进行unpooling时,相比较随机分配pooling的位置,或者固定分配pooling的位置可以极大的增加像素级分类图像的边缘形状;相比较于传统的deconv上采样,则是直接省略掉了学习的过程(deconv既需要进行上采样的学习,又需要特征还原的学习),使得模型只专注于特征还原的学习。个人认为,这里的maxpool具有传统的skip connect的整合local和global信息的作用,但是却使得计算量下降了一半(传统的 lateral connection后,图像通道数增加了一倍)。

不足:
在SeNet中,是通过从前而后的方法计算出每一个像素点在每个类别的概率,并通过softmax激活(选择概率最大的类别)。由于在编码和解码的过渡中,高维数据不可控,这就会导致一个问题:这种先验概率计算方法,我们无法保证最后softmax的结果是正确的。

改进:
为了解决前面所提到的先验概率结果正确率的不可控性,提出了使用后验概率的Bayesian SegNet:
网图,侵删
Bayesian SegNet相比较于传统的SegNet最大的进步就是,可以给出结果的置信度(这里可以非常明确,Bayesian SegNet对于传统SegNet的性能并没有提升)。结构上看,Bayesian SegNet不过是相对于SegNet添加了一个DropOut层。
DropOut层的工作原理,是在每一轮训练时,对每个神经元添加了一个激活概率(激活时,神经元会得到正常训练;未激活时,神经元不会被训练),这里我们可以将激活看做1,不激活看做0,论文中将激活概率设置为0.5。
由蒙特卡洛抽样可知,当试验次数足够多时,频率可以看做事件发生的概率,因此通过蒙特卡罗抽样,就可以求出一个新分布的均值与方差,这样使用方差大小就可以知道一个分布对于样本的差异性,我们知道方差越大差异越大。反应在图像上,方差大的地方,就是分类不确定性大的地方。
网图,侵删
图片的第一行是分割原图,第二行是ground true,第三行是分割结果,第四行是Bayesian SegNet输出的不确定性图,且颜色越深,不确定性越大。
我们可以看到不确定性大的地方,主要有3点:
1)两种类别的边界处
2)由于遮挡,或者复杂形状而难以识别的物体
3)模糊的分类(如狗和猫,椅子和桌子)

代码分析(以SegNet Basic为例):
看了好多博主给的代码,都有问题,于是自己费了一番力气从GitHub上找了一份自认为还行的。
下面展示模型核心代码

class SegNet(nn.Module):
	'''初始化网络结构'''
    def __init__(self,input_nbr,label_nbr):
        super(SegNet, self).__init__()

        batchNorm_momentum = 0.1

        self.conv11 = nn.Conv2d(input_nbr, 64, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm2d(64, momentum= batchNorm_momentum)
        self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm2d(64, momentum= batchNorm_momentum)

        self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn21 = nn.BatchNorm2d(128, momentum= batchNorm_momentum)
        self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn22 = nn.BatchNorm2d(128, momentum= batchNorm_momentum)

        self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn31 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
        self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn32 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
        self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn33 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)

        self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn41 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn42 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn43 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)

        self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn51 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn52 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn53 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)

        self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn53d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn52d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn51d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)

        self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn43d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn42d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.bn41d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)

        self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn33d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
        self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn32d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
        self.conv31d = nn.Conv2d(256,  128, kernel_size=3, padding=1)
        self.bn31d = nn.BatchNorm2d(128, momentum= batchNorm_momentum)

        self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn22d = nn.BatchNorm2d(128, momentum= batchNorm_momentum)
        self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn21d = nn.BatchNorm2d(64, momentum= batchNorm_momentum)

        self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn12d = nn.BatchNorm2d(64, momentum= batchNorm_momentum)
        self.conv11d = nn.Conv2d(64, label_nbr, kernel_size=3, padding=1)


    def forward(self, x):
		'''网络的计算过程'''
		
		'''Stage 1、2、3、4、5为encoder过程'''
		'''每一个过程可以分解为两次size不变的卷积 + 一次size缩小一半的maxpool'''
        # Stage 1
        x11 = F.relu(self.bn11(self.conv11(x)))
        x12 = F.relu(self.bn12(self.conv12(x11)))
        x1p, id1 = F.max_pool2d(x12,kernel_size=2, stride=2,return_indices=True)

        # Stage 2
        x21 = F.relu(self.bn21(self.conv21(x1p)))
        x22 = F.relu(self.bn22(self.conv22(x21)))
        x2p, id2 = F.max_pool2d(x22,kernel_size=2, stride=2,return_indices=True)
        
        # Stage 3
        x31 = F.relu(self.bn31(self.conv31(x2p)))
        x32 = F.relu(self.bn32(self.conv32(x31)))
        x33 = F.relu(self.bn33(self.conv33(x32)))
        x3p, id3 = F.max_pool2d(x33,kernel_size=2, stride=2,return_indices=True)

        # Stage 4
        x41 = F.relu(self.bn41(self.conv41(x3p)))
        x42 = F.relu(self.bn42(self.conv42(x41)))
        x43 = F.relu(self.bn43(self.conv43(x42)))
        x4p, id4 = F.max_pool2d(x43,kernel_size=2, stride=2,return_indices=True)

        # Stage 5
        x51 = F.relu(self.bn51(self.conv51(x4p)))
        x52 = F.relu(self.bn52(self.conv52(x51)))
        x53 = F.relu(self.bn53(self.conv53(x52)))
        x5p, id5 = F.max_pool2d(x53,kernel_size=2, stride=2,return_indices=True)

		'''Stage 6、7、8、9、10为decoder过程'''
		'''每一个过程可以分解为一次size放大一倍的带index的unpooling + 三次size不变的卷积'''
        # Stage 6
        x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2)
        x53d = F.relu(self.bn53d(self.conv53d(x5d)))
        x52d = F.relu(self.bn52d(self.conv52d(x53d)))
        x51d = F.relu(self.bn51d(self.conv51d(x52d)))

        # Stage 7
        x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2)
        x43d = F.relu(self.bn43d(self.conv43d(x4d)))
        x42d = F.relu(self.bn42d(self.conv42d(x43d)))
        x41d = F.relu(self.bn41d(self.conv41d(x42d)))

        # Stage 8
        x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2)
        x33d = F.relu(self.bn33d(self.conv33d(x3d)))
        x32d = F.relu(self.bn32d(self.conv32d(x33d)))
        x31d = F.relu(self.bn31d(self.conv31d(x32d)))

        # Stage 9
        x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2)
        x22d = F.relu(self.bn22d(self.conv22d(x2d)))
        x21d = F.relu(self.bn21d(self.conv21d(x22d)))

        # Stage 10
        x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2)
        x12d = F.relu(self.bn12d(self.conv12d(x1d)))
        x11d = self.conv11d(x12d)

        return x11d

    def load_from_segnet(self, model_path):
        s_dict = self.state_dict()# create a copy of the state dict
        th = torch.load(model_path).state_dict() # load the weigths
        # for name in th:
            # s_dict[corresp_name[name]] = th[name]
        self.load_state_dict(th)

  • 8
    点赞
  • 34
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值