4 - Image Segmentation: SegNet & DeconvNet

This article introduces the importance of the encoder-decoder architecture for image segmentation, taking SegNet and DeconvNet as examples. SegNet achieves precise segmentation by reusing max-pooling indices, while DeconvNet relies on deconvolution and unpooling; both recover image detail while keeping overfitting in check. The PyTorch implementations walk through the structure and forward pass of both models, highlighting the importance of managing feature-map sizes and recovering spatial information.

1. Preliminaries

Encoders and decoders:

Encoder structure:
The encoder consists mainly of ordinary convolution layers and downsampling layers that shrink the feature maps into a lower-dimensional representation. The goal is to extract as many low-level and high-level features as possible, so that the spatial and global information they carry can be used for precise segmentation.
Decoder structure:
The decoder consists mainly of ordinary convolutions, upsampling layers, and fusion layers. Upsampling gradually restores the spatial dimensions, and fusion layers merge the features extracted during encoding, producing an output the same size as the input with as little information loss as possible.
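
To make this concrete, here is a minimal toy sketch (not any published model; the layer widths and the choice of bilinear upsampling are arbitrary, for illustration only) of a network that downsamples and then restores the input resolution:

import torch
from torch import nn
import torch.nn.functional as F

class ToyEncoderDecoder(nn.Module):
    def __init__(self, num_classes=12):
        super().__init__()
        # encoder: convolutions + downsampling shrink the feature map to H/4 x W/4
        self.enc = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # decoder: a convolution refines the upsampled feature map
        self.dec = nn.Conv2d(32, 16, 3, padding=1)
        self.classifier = nn.Conv2d(16, num_classes, 1)

    def forward(self, x):
        h = self.enc(x)
        # upsample back to the input resolution (here by interpolation;
        # SegNet/DeconvNet use unpooling/deconvolution instead)
        h = F.interpolate(h, scale_factor=4, mode='bilinear', align_corners=False)
        return self.classifier(F.relu(self.dec(h)))

print(ToyEncoderDecoder()(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 12, 64, 64])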

Dropout:

When a complex feed-forward neural network is trained on a small dataset, it tends to overfit. To prevent this, the network's performance can be improved by stopping feature detectors from co-adapting. Dropout is one such trick available for training deep neural networks: in each training batch, a fraction of the feature detectors is ignored (for example, half of the hidden units are set to 0), which noticeably reduces overfitting by limiting the interactions between feature detectors (hidden units).
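
As a quick standalone demo (not part of the SegNet/DeconvNet code below), nn.Dropout zeroes units at random during training and is a no-op in eval mode:

import torch
from torch import nn

drop = nn.Dropout(p=0.5)  # each unit is zeroed with probability 0.5
x = torch.ones(1, 8)

drop.train()
print(drop(x))  # roughly half the units are 0; survivors are scaled by 1/(1-p) = 2

drop.eval()
print(drop(x))  # identity at inference time: all ones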
 

Unpooling:

Upsampling is usually done in one of two ways: interpolation or deconvolution (transposed convolution). Here we introduce a third: unpooling.
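
For reference, the first two methods look like this in PyTorch (a standalone sketch; the channel count and sizes are arbitrary):

import torch
from torch import nn
import torch.nn.functional as F

x = torch.randn(1, 64, 22, 30)

# 1) interpolation: parameter-free upsampling
print(F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False).shape)
# torch.Size([1, 64, 44, 60])

# 2) deconvolution (transposed convolution): learned upsampling
deconv = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
print(deconv(x).shape)  # torch.Size([1, 64, 44, 60])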

The indices of every max-pooling layer in the encoder are stored, and the decoder later uses those stored indices to unpool the corresponding feature maps. This helps keep high-frequency information intact, but when unpooling a low-resolution feature map it also discards the neighboring information.
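
A small standalone demo of this third method: MaxPool2d with return_indices=True records where each maximum came from, and MaxUnpool2d writes the values back at exactly those positions, filling everywhere else with zeros (these zeros are the discarded neighboring information):

import torch
from torch import nn

pool = nn.MaxPool2d(2, 2, return_indices=True)  # also return the argmax positions
unpool = nn.MaxUnpool2d(2, 2)

x = torch.arange(16.).reshape(1, 1, 4, 4)
y, indices = pool(x)           # y: 1x1x2x2, indices: where each max was located
restored = unpool(y, indices)  # 1x1x4x4: maxima back in place, zeros elsewhere
print(restored)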
 

2. Research Results and Significance

SegNet:

  • Strikes a good balance between memory (parameter count) and accuracy
  • Popularized the encoder-decoder structure
  • Achieved strong results on several scene datasets

SegNet consists of an encoder and a decoder; the decoder's upsampling reuses the pooling information recorded during encoding. SegNet reaches an mIoU of 60.10 on the CamVid dataset (at present, 65-70 counts as a decent model), and the experimental section of the paper is very thorough.

My summary of SegNet's key points:

  • Encoder-decoder structure
  • Reuse of max-pooling indices

Network architecture:

[Figure: the symmetric SegNet encoder-decoder architecture]

As the figure shows, the network is completely symmetric: the input image passes through five convolution + pooling stages, followed by five upsampling + convolution stages. Notably, the information recorded during the downsampling pooling steps is mapped back into the upsampling path.

DeconvNet:

DeconvNet proposes a deep deconvolution network: the encoder uses VGG16, and the decoder uses deconvolution and unpooling. During training the image is fed in region by region, which requires a certain amount of manual intervention.

Network architecture:

[Figure: the DeconvNet architecture]

Likewise, the network repeatedly convolves and pools, then upsamples. It differs from SegNet on two points, plus one note on the paper itself:

  • Two fully connected layers are inserted in the middle (in hindsight this is questionable: it inflates the parameter count, and it flattens the originally 2D segmentation problem into a 1D vector)
  • The upsampling path uses deconvolution (transposed convolution) rather than ordinary convolution; as designed here, the deconvolutions do not enlarge the feature map (see the sketch after this list), and the authors argue that deconvolution produces denser feature maps
  • The paper's writing and logical structure are a classic example of how to present this kind of work
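
To verify the size claim in the second bullet (a standalone sketch, not code from the paper): with kernel_size=3, stride=1, padding=1, nn.ConvTranspose2d leaves the spatial size unchanged, which is exactly the configuration used in the DeconvNet decoder below, whereas stride=2 would enlarge it:

import torch
from torch import nn

x = torch.randn(1, 64, 22, 30)

# kernel 3, stride 1, padding 1: output H, W unchanged
same = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1, padding=1)
print(same(x).shape)  # torch.Size([1, 64, 22, 30])

# stride 2 with output_padding 1: output H, W doubled
up = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
print(up(x).shape)    # torch.Size([1, 64, 44, 60])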

3. PyTorch Implementation of SegNet & DeconvNet

SegNet.py

import torch
from torch import nn
import torchvision.models as models
import warnings
warnings.filterwarnings("ignore")

vgg16_pretrained = models.vgg16(pretrained=True)  # load ImageNet-pretrained weights


def decoder(input_channel, output_channel, num=3):  # num selects a 3-conv or 2-conv decoder block
    # note: the SegNet paper also applies BatchNorm + ReLU after each conv; omitted here for simplicity
    if num == 3:
        decoder_body = nn.Sequential(
            nn.Conv2d(input_channel, input_channel, 3, padding=1),
            nn.Conv2d(input_channel, input_channel, 3, padding=1),
            nn.Conv2d(input_channel, output_channel, 3, padding=1),
        )
    elif num == 2:
        decoder_body = nn.Sequential(
            nn.Conv2d(input_channel, input_channel, 3, padding=1),
            nn.Conv2d(input_channel, output_channel, 3, padding=1),
        )

    return decoder_body


class VGG16_SegNet(torch.nn.Module):
    def __init__(self):
        super(VGG16_SegNet, self).__init__()

        pool_list = [4, 9, 16, 23, 30]  # positions of the five MaxPool2d layers in vgg16.features
        for index in pool_list:
            vgg16_pretrained.features[index].return_indices = True  # make the pools return their indices

        self.encoder1 = vgg16_pretrained.features[:4]
        self.pool1 = vgg16_pretrained.features[4]

        self.encoder2 = vgg16_pretrained.features[5:9]
        self.pool2 = vgg16_pretrained.features[9]

        self.encoder3 = vgg16_pretrained.features[10:16]
        self.pool3 = vgg16_pretrained.features[16]

        self.encoder4 = vgg16_pretrained.features[17:23]
        self.pool4 = vgg16_pretrained.features[23]

        self.encoder5 = vgg16_pretrained.features[24:30]
        self.pool5 = vgg16_pretrained.features[30]

        self.decoder5 = decoder(512, 512)
        self.unpool5 = nn.MaxUnpool2d(2, 2)

        self.decoder4 = decoder(512, 256)
        self.unpool4 = nn.MaxUnpool2d(2, 2)

        self.decoder3 = decoder(256, 128)
        self.unpool3 = nn.MaxUnpool2d(2, 2)

        self.decoder2 = decoder(128, 64, 2)
        self.unpool2 = nn.MaxUnpool2d(2, 2)

        self.decoder1 = decoder(64, 12, 2)  # 12 output channels = number of CamVid classes
        self.unpool1 = nn.MaxUnpool2d(2, 2)

    def forward(self, x):  # input: 3, 352, 480
        print('input_image:', x.size())

        encoder1 = self.encoder1(x)
        print('encoder1:', encoder1.size())  # 64, 352, 480
        output_size1 = encoder1.size()  # pre-pooling size, needed later for unpooling
        print('output_size1:', output_size1)
        pool1, indices1 = self.pool1(encoder1)  # 64, 176, 240
        print('pool1:', pool1.size())
        print('indices1:', indices1.size())

        encoder2 = self.encoder2(pool1)
        print('encoder2:', encoder2.size())  # 128, 176, 240
        output_size2 = encoder2.size()
        print('output_size2:', output_size2)
        pool2, indices2 = self.pool2(encoder2)  # 128, 88, 120
        print('pool2:', pool2.size())
        print('indices2:', indices2.size())

        encoder3 = self.encoder3(pool2)
        print('encoder3:', encoder3.size())  # 256, 88, 120
        output_size3 = encoder3.size()
        print('output_size3:', output_size3)
        pool3, indices3 = self.pool3(encoder3)  # 256, 44, 60
        print('pool3:', pool3.size())
        print('indices3:', indices3.size())

        encoder4 = self.encoder4(pool3)
        print('encoder4:', encoder4.size())  # 512, 44, 60
        output_size4 = encoder4.size()
        print('output_size4:', output_size4)
        pool4, indices4 = self.pool4(encoder4)  # 512, 22, 30
        print('pool4:', pool4.size())
        print('indices4:', indices4.size())

        encoder5 = self.encoder5(pool4)
        print('encoder5:', encoder5.size())  # 512, 22, 30
        output_size5 = encoder5.size()
        print('output_size5:', output_size5)
        pool5, indices5 = self.pool5(encoder5)  # 512, 11, 15; pool5 and indices5 have the same shape
        print('pool5:', pool5.size())
        print('indices5:', indices5.size())

        unpool5 = self.unpool5(input=pool5, indices=indices5, output_size=output_size5)
        print('unpool5:', unpool5.size())  # 512, 22, 30
        decoder5 = self.decoder5(unpool5)
        print('decoder5:', decoder5.size())  # 512, 22, 30

        unpool4 = self.unpool4(input=decoder5, indices=indices4, output_size=output_size4)
        print('unpool4:', unpool4.size())  # 512, 44, 60
        decoder4 = self.decoder4(unpool4)
        print('decoder4:', decoder4.size())  # 256, 44, 60

        unpool3 = self.unpool3(input=decoder4, indices=indices3, output_size=output_size3)
        print('unpool3:', unpool3.size())  # 256, 88, 120
        decoder3 = self.decoder3(unpool3)
        print('decoder3:', decoder3.size())  # 128, 88, 120

        unpool2 = self.unpool2(input=decoder3, indices=indices2, output_size=output_size2)
        print('unpool2:', unpool2.size())  # 128, 176, 240
        decoder2 = self.decoder2(unpool2)
        print('decoder2:', decoder2.size())  # 64, 176, 240

        unpool1 = self.unpool1(input=decoder2, indices=indices1, output_size=output_size1)
        print('unpool1:', unpool1.size())  # 64, 352, 480
        decoder1 = self.decoder1(unpool1)
        print('decoder1:', decoder1.size())  # 12, 352, 480

        return decoder1


if __name__ == "__main__":
    rgb = torch.randn(1, 3, 352, 480)  # a random 352x480 RGB batch
    net = VGG16_SegNet()
    out = net(rgb)
    print(out.shape)  # torch.Size([1, 12, 352, 480])

DeconvNet.py

import torch
import torchvision.models as models
from torch import nn
import warnings
warnings.filterwarnings("ignore")

vgg16_pretrained = models.vgg16(pretrained=True)  # load ImageNet-pretrained weights

def decoder(input_channel, output_channel, num=3):  # num selects a 3-layer or 2-layer deconv block
    # transposed convolutions with kernel 3, stride 1, padding 1 keep the spatial size unchanged
    if num == 3:
        decoder_body = nn.Sequential(
            nn.ConvTranspose2d(input_channel, input_channel, 3, padding=1),
            nn.ConvTranspose2d(input_channel, input_channel, 3, padding=1),
            nn.ConvTranspose2d(input_channel, output_channel, 3, padding=1))
    elif num == 2:
        decoder_body = nn.Sequential(
            nn.ConvTranspose2d(input_channel, input_channel, 3, padding=1),
            nn.ConvTranspose2d(input_channel, output_channel, 3, padding=1))

    return decoder_body


class VGG16_DeconvNet(torch.nn.Module):
    def __init__(self):
        super(VGG16_DeconvNet, self).__init__()

        pool_list = [4, 9, 16, 23, 30]  # positions of the five MaxPool2d layers in vgg16.features
        for index in pool_list:
            vgg16_pretrained.features[index].return_indices = True  # make the pools return their indices

        self.encoder1 = vgg16_pretrained.features[:4]
        self.pool1 = vgg16_pretrained.features[4]

        self.encoder2 = vgg16_pretrained.features[5:9]
        self.pool2 = vgg16_pretrained.features[9]

        self.encoder3 = vgg16_pretrained.features[10:16]
        self.pool3 = vgg16_pretrained.features[16]

        self.encoder4 = vgg16_pretrained.features[17:23]
        self.pool4 = vgg16_pretrained.features[23]

        self.encoder5 = vgg16_pretrained.features[24:30]
        self.pool5 = vgg16_pretrained.features[30]

        # the two fully connected layers in the middle of DeconvNet;
        # 512 * 11 * 15 is the flattened size of pool5 for a 352x480 input
        self.classifier = nn.Sequential(
            torch.nn.Linear(512 * 11 * 15, 4096),
            torch.nn.ReLU(),
            torch.nn.Linear(4096, 512 * 11 * 15),
            torch.nn.ReLU(),
        )

        self.decoder5 = decoder(512, 512)
        self.unpool5 = nn.MaxUnpool2d(2, 2)

        self.decoder4 = decoder(512, 256)
        self.unpool4 = nn.MaxUnpool2d(2, 2)

        self.decoder3 = decoder(256, 128)
        self.unpool3 = nn.MaxUnpool2d(2, 2)

        self.decoder2 = decoder(128, 64, 2)
        self.unpool2 = nn.MaxUnpool2d(2, 2)

        self.decoder1 = decoder(64, 12, 2)  # 12 output channels = number of CamVid classes
        self.unpool1 = nn.MaxUnpool2d(2, 2)

    def forward(self, x):
        encoder1 = self.encoder1(x)
        output_size1 = encoder1.size()
        pool1, indices1 = self.pool1(encoder1)

        encoder2 = self.encoder2(pool1)
        output_size2 = encoder2.size()
        pool2, indices2 = self.pool2(encoder2)

        encoder3 = self.encoder3(pool2)
        output_size3 = encoder3.size()
        pool3, indices3 = self.pool3(encoder3)

        encoder4 = self.encoder4(pool3)
        output_size4 = encoder4.size()
        pool4, indices4 = self.pool4(encoder4)

        encoder5 = self.encoder5(pool4)
        output_size5 = encoder5.size()
        pool5, indices5 = self.pool5(encoder5)
        print('pool5:', pool5.size())

        pool5 = pool5.view(pool5.size(0), -1)  # flatten to (batch, 512*11*15) before the FC layers
        print('pool5-view:', pool5.size())
        fc = self.classifier(pool5)
        print('fc1:', fc.size())
        fc = fc.reshape(fc.size(0), 512, 11, 15)  # restore the 4D shape; batch-size aware, not hardcoded to 1
        print('fc2:', fc.size())

        unpool5 = self.unpool5(input=fc, indices=indices5, output_size=output_size5)  # 512, 22, 30
        decoder5 = self.decoder5(unpool5)  # 512, 22, 30

        unpool4 = self.unpool4(input=decoder5, indices=indices4, output_size=output_size4)  # 512, 44, 60
        decoder4 = self.decoder4(unpool4)  # 256, 44, 60

        unpool3 = self.unpool3(input=decoder4, indices=indices3, output_size=output_size3)  # 256, 88, 120
        decoder3 = self.decoder3(unpool3)  # 128, 88, 120

        unpool2 = self.unpool2(input=decoder3, indices=indices2, output_size=output_size2)  # 128, 176, 240
        decoder2 = self.decoder2(unpool2)  # 64, 176, 240

        unpool1 = self.unpool1(input=decoder2, indices=indices1, output_size=output_size1)  # 64, 352, 480
        decoder1 = self.decoder1(unpool1)  # 12, 352, 480

        return decoder1


if __name__ == "__main__":
    rgb = torch.randn(1, 3, 352, 480)  # a random 352x480 RGB batch
    net = VGG16_DeconvNet()
    out = net(rgb)
    print(out.shape)  # torch.Size([1, 12, 352, 480])

DeconvNet uses fully connected layers, so the feature-map sizes need care; see the snippet below.

import torch

x = torch.randn(2, 84480)  # input of shape (2, 84480): batch_size=2, features=512*11*15
print('x:', x.size())

m1 = torch.nn.Linear(512 * 11 * 15, 4096)
output1 = m1(x)
print('output1:', output1.size())  # (2, 4096)

m2 = torch.nn.Linear(4096, 512 * 11 * 15)
output2 = m2(output1)
print('output2:', output2.size())  # (2, 84480)
output2 = output2.reshape(2, 512, 11, 15)
print('output2:', output2.size())  # (2, 512, 11, 15)

(2, 84480) corresponds to the shape of pool5 after .view(). You cannot feed the pool5 feature map straight into nn.Linear() directly: it raises a size-mismatch error, because nn.Linear() matches in_features against the last dimension of its input, so the 4D pool5 tensor must first be flattened into a 2D (batch, features) tensor. Also note that in_features is the per-sample feature size, not batch_size * size: for the ([2, 84480]) input above, the fully connected layer's in_features should be 512 * 11 * 15. After m1 the shape becomes ([2, 4096]), and after m2 it returns to ([2, 84480]).

References:

深度之眼 course on Bilibili (B站)

PyTorch的nn.Linear()详解, douzujun, 博客园 (cnblogs)
