Pytorch:Unet网络代码详解

  pytorch版本的Unet网络可以去github上面下载,网址为https://github.com/milesial/Pytorch-UNet,话不多说,还是以代码为例吧。
有小伙伴问我pytorch的型号,发图给大家参考一下,文章写得有点久了…好多东西我自己都记不太清楚了,体谅一下~
在这里插入图片描述

1、dataset.py

  这个数据集采用的是汽车的数据集,数据集当中返回的是一个字典:

        return {
            'image': torch.from_numpy(img).type(torch.FloatTensor),
            'mask': torch.from_numpy(mask).type(torch.FloatTensor)
        }

  image返回的则是汽车的图片,如下图:

  mask则返回的是图层蒙版,如下图:

2、Unet模型

  代码分为Unet_model.py以及Unet_part.py
  Unet网络图如下所示:

  再看一下网络大体的代码结构:
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear 

  n_classes:希望获得的每个像素的概率数,对于一个类和背景,使用n_classes=1,这里输出的就是黑白对照,所以使用1;n_channels=3是因为输入的图片是RGB 图像,因此是三维;bilinear则用于上采样。

        self.inc = DoubleConv(n_channels, 64)

  首先输入一张图片,通过DoubleConv将通道数变为64,图片的大小改变就对应的公式[(n1-n2)/s+1],(其中n1是图片大小,n2是卷积核大小,s是滑动步长,默认为1,因此图片大小由572->570->568),DoubleConv对应的就是下采样中的每一行的卷积与Relu,可以看到每一行的通道数是没有发生改变,找到这部分的代码:

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.double_conv(x)

  DoubleConv主要是用于两次卷积,如果有中间通道的话可以作为桥梁,先卷积到中间通道数,然后再卷积到输出通道数。

        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)

  再看一下下采样的后续过程,找到Down代码:

class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down,self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

  主要是先进行最大池化,将图片大小变为原来的一半,然后再采用DoubleConv增加通道数。这样经过3次下采样,可以看到图片通道数为512,大小为64*64,此时还需要进行第4次下采样,由于后续要进行上采样,需要将每一层上采样对应的特征图与下采样对应的特征图进行融合,能够充分获得有用信息,融合时需要通道数进行对应,因此输出通道数为512,且图片大小为28*28,对应的forward代码为:

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

  接着看一下后续上采样初始化的代码:

        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)

  上采样则采用了bilinear,看一下Up的代码:

class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up,self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

  在nn.Upsample函数中,scale_factor指定输出大小为输入的多少倍数,mode:可使用的上采样算法,align_corners为True,输入的角像素将与输出张量对齐,因此将保存下来这些像素的值,nn.ConvTranspose2d是反卷积,对卷积层进行上采样,使其回到原始图片的分辨率。而对应的forward代码为:

  def forward(self, x1, x2):
        x1 = self.up(x1)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)#进行融合裁剪
        return self.conv(x)

  上采样的过程中需要对两个特征图进行融合,通道数一样并且尺寸也应该一样,x1是上采样获得的特征,而x2是下采样获得的特征,首先对x1进行反卷积使其大小变为输入时的2倍,首先需要计算两张图长宽的差值,作为填补padding的依据,由于此时图片的表示为(C,H,W),因此diffY对应的图片的高,diffX对应图片的宽度, F.pad指的是(左填充,右填充,上填充,下填充),其数值代表填充次数,因此需要/2,最后进行融合剪裁。
  上采样所对应的forword代码:

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

  以第一层上采样为例,x5对应的是最后一次下采样获得的图片,通道数为512,大小为28*28,x4是第三次下采样获得的图片,通道为512,大小为64*64,首先将x5的特征大小变为2倍为56*56,然后长宽差距为8,所以周围分别补4个0,再和x4进行竖向拼接,因此输出通道数为1024,大小为64*64,然后就继续进行三次上采样,最终获得的图片通道为64,大小为572(跟图不符,但是用过程是没问题的,用代码测试过了),此时就已经变成了跟原来图片大小了,接着:

	self.outc = OutConv(64, n_classes)

  看一下OutConv代码:

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

  仍然采用了两次卷积,此时由于卷积核为1*1大小,因此不改变图片的大小,forword代码为:

		logits = self.outc(x)

  看一下每一层输出的结果吧,设原始图片大小为[1, 3, 572, 572]:

输入图片: torch.Size([1, 3, 572, 572])
下采样x1: torch.Size([1, 64, 572, 572])
下采样x2: torch.Size([1, 128, 286, 286])
下采样x3: torch.Size([1, 256, 143, 143])
下采样x4: torch.Size([1, 512, 71, 71])
下采样x5: torch.Size([1, 512, 35, 35])
上采样x4: torch.Size([1, 256, 71, 71])
上采样x3: torch.Size([1, 128, 143, 143])
上采样x2: torch.Size([1, 64, 286, 286])
上采样x1: torch.Size([1, 64, 572, 572])
输出图片: torch.Size([1, 1, 572, 572])

  此时输出的则是黑白图片了,黑白图片plt输出要压缩到2维才行。
  尝试进行一下输出:

  • 44
    点赞
  • 299
    收藏
    觉得还不错? 一键收藏
  • 21
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 21
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值