U-Net Network Architecture

To prepare for the autumn recruitment season, I have recently been reviewing the deep learning algorithms I studied before. I had already reviewed the fundamentals and network architectures of other algorithms as well, but it never occurred to me to write them up. Today it struck me that a blog post would be worth keeping: it serves my own future review, and it may also help readers who are just getting started with this algorithm.

This post mainly presents the U-Net architecture as I implemented it, partly drawing on other people's code; for detailed explanations, please refer to the blogs listed at the end.

The content here is likewise summarized from notes by others online; if you find any problems, please point them out so we can learn from each other!

1. U-Net Basics

The entire network extracts features using convolutions only. The architecture splits into two halves: the left half (the contracting path) and the right half (the expansive path). The contracting path builds the deep part of the network and extracts high-level semantic information; it is also called the encoder of U-Net. It consists of four blocks, each containing two convolutions and one max pooling, where max pooling is the downsampling operation that halves the width and height of the feature map each time. The expansive path likewise has four blocks, each containing two convolutions and one upsampling step. Because downsampling loses some detail, each upsampling step uses a copy-and-crop operation to crop the shallow encoder feature maps and concatenate them with the decoder features, which preserves more of the fine-grained information (mind the spatial sizes when concatenating: the encoder features must be cropped to match).
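
A worked size trace makes the cropping requirement concrete. Below is a minimal sketch (my own illustration, assuming the paper's 572×572 single-channel input) that prints the encoder feature-map sizes; the resulting 568, 280, 136, 64, and 28 match the layer summary shown later in this post.

# Sketch: trace feature-map sizes down the contracting path for a 572x572
# input. Each 3x3 conv with padding=0 shrinks height/width by 2; each
# 2x2 max pooling halves them.
size = 572
for level in range(5):
    size = size - 2 - 2                  # two 3x3 valid convolutions
    print(f"encoder level {level + 1}: {size}x{size}")
    if level < 4:                        # the deepest block is not pooled
        size = size // 2                 # 2x2 max pooling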

2. Main Code

import torch.nn as nn
from torchvision import transforms
import torch

class Double_convs(nn.Module):
    """Two 3x3 convolutions with padding=0 (valid convolutions), each
    followed by BatchNorm and ReLU; each conv shrinks height/width by 2."""
    def __init__(self, in_channels, out_channels) -> None:
        super(Double_convs, self).__init__()

        self.block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=0),
                                   nn.BatchNorm2d(out_channels),
                                   nn.ReLU(inplace=True),

                                   nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=0),
                                   nn.BatchNorm2d(out_channels),
                                   nn.ReLU(inplace=True))
    def forward(self, x):
        x = self.block(x)

        return x

class Block_DownSample(nn.Module):
    """Contracting path (encoder): five double-conv blocks with 2x2 max
    pooling between them; each pooling halves the feature-map size."""
    def __init__(self) -> None:
        super(Block_DownSample, self).__init__()

        self.block_1 = Double_convs(1, 64)
        self.block_2 = Double_convs(64, 128)
        self.block_3 = Double_convs(128, 256)
        self.block_4 = Double_convs(256, 512)
        self.block_5 = Double_convs(512, 1024)

        self.maxpooling = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
    
    def forward(self, x):
        # keep every block's output: out_1..out_4 feed the decoder's skip connections
        out_1 = self.block_1(x)

        x = self.maxpooling(out_1)
        out_2 = self.block_2(x)

        x = self.maxpooling(out_2)
        out_3 = self.block_3(x)

        x = self.maxpooling(out_3)
        out_4 = self.block_4(x)

        x = self.maxpooling(out_4)
        out_5 = self.block_5(x)

        return out_1, out_2, out_3, out_4, out_5
    
class Block_UpSample(nn.Module):
    """One decoder block: bilinear upsampling plus a 1x1 conv to halve the
    channels (used here instead of the paper's 2x2 transposed convolution),
    then concatenation with the cropped skip features and a double conv."""
    def __init__(self, in_channels, out_channels) -> None:
        super(Block_UpSample, self).__init__()

        self.upsample_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                                        nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0),
                                        nn.ReLU(inplace=True))
        self.block_1 = Double_convs(in_channels, out_channels)
        
    def forward(self, x1, x2, wh):
        # x1: decoder features; x2: encoder (skip) features;
        # wh: the height/width x2 is cropped to so it matches x1 after upsampling
        x1 = self.upsample_1(x1)

        centercrop = transforms.CenterCrop((wh, wh))    # I use center cropping here; I am not sure whether this is right
        x2 = centercrop(x2)
        x = torch.cat([x1, x2], dim=1)
        x = self.block_1(x)

        return x

class U_net(nn.Module):
    def __init__(self) -> None:
        super(U_net, self).__init__()
        self.down = Block_DownSample()
        self.up_1 = Block_UpSample(1024, 512)
        self.up_2 = Block_UpSample(512, 256)
        self.up_3 = Block_UpSample(256, 128)
        self.up_4 = Block_UpSample(128, 64)
        self.conv = nn.Sequential(nn.Conv2d(64, 2, kernel_size=1))    # 1x1 conv: 64 channels -> 2 output classes

    def forward(self, x):
        out_1, out_2, out_3, out_4, out_5 = self.down(x)

        # the crop sizes below (56, 104, 200, 392) assume a 572x572 input
        x = self.up_1(out_5, out_4, 56)
        x = self.up_2(x, out_3, 104)
        x = self.up_3(x, out_2, 200)
        x = self.up_4(x, out_1, 392)
        x = self.conv(x)

        return x
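
As a quick sanity check, you can push a random image through the network; the snippet below is a minimal sketch, and note that the crop sizes hard-coded above assume exactly a 572×572 input.

# Sanity check (sketch): a 1x572x572 input should produce a 2x388x388 output.
if __name__ == '__main__':
    model = U_net()
    x = torch.randn(1, 1, 572, 572)    # one single-channel 572x572 image
    y = model(x)
    print(y.shape)                     # expected: torch.Size([1, 2, 388, 388])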

I have seen other bloggers mention that setting padding=1 in the encoder's convolutions lets you concatenate directly without any cropping; interested readers can give it a try (a sketch of that variant follows).
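
For example, the padded variant of the double conv could look like the sketch below (my own illustration, not code from the referenced posts). With padding=1 every 3×3 conv preserves the spatial size, so the encoder and decoder feature maps already line up after upsampling and the CenterCrop step becomes unnecessary (as long as the input side length is divisible by 16).

# Sketch of a padded double conv (illustration only): with padding=1 each
# 3x3 conv keeps the height/width unchanged, so skip features need no
# cropping before concatenation.
class Double_convs_padded(nn.Module):
    def __init__(self, in_channels, out_channels) -> None:
        super(Double_convs_padded, self).__init__()

        self.block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                                   nn.BatchNorm2d(out_channels),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                                   nn.BatchNorm2d(out_channels),
                                   nn.ReLU(inplace=True))

    def forward(self, x):
        return self.block(x)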

The input/output information for every layer of the network (for a 1×572×572 input) is shown below:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param # 
================================================================
            Conv2d-1         [-1, 64, 570, 570]             640
       BatchNorm2d-2         [-1, 64, 570, 570]             128
              ReLU-3         [-1, 64, 570, 570]               0
            Conv2d-4         [-1, 64, 568, 568]          36,928
       BatchNorm2d-5         [-1, 64, 568, 568]             128
              ReLU-6         [-1, 64, 568, 568]               0
      Double_convs-7         [-1, 64, 568, 568]               0
         MaxPool2d-8         [-1, 64, 284, 284]               0
            Conv2d-9        [-1, 128, 282, 282]          73,856
      BatchNorm2d-10        [-1, 128, 282, 282]             256
             ReLU-11        [-1, 128, 282, 282]               0
           Conv2d-12        [-1, 128, 280, 280]         147,584
      BatchNorm2d-13        [-1, 128, 280, 280]             256
             ReLU-14        [-1, 128, 280, 280]               0
     Double_convs-15        [-1, 128, 280, 280]               0
        MaxPool2d-16        [-1, 128, 140, 140]               0
           Conv2d-17        [-1, 256, 138, 138]         295,168
      BatchNorm2d-18        [-1, 256, 138, 138]             512
             ReLU-19        [-1, 256, 138, 138]               0
           Conv2d-20        [-1, 256, 136, 136]         590,080
      BatchNorm2d-21        [-1, 256, 136, 136]             512
             ReLU-22        [-1, 256, 136, 136]               0
     Double_convs-23        [-1, 256, 136, 136]               0
        MaxPool2d-24          [-1, 256, 68, 68]               0
           Conv2d-25          [-1, 512, 66, 66]       1,180,160
      BatchNorm2d-26          [-1, 512, 66, 66]           1,024
             ReLU-27          [-1, 512, 66, 66]               0
           Conv2d-28          [-1, 512, 64, 64]       2,359,808
      BatchNorm2d-29          [-1, 512, 64, 64]           1,024
             ReLU-30          [-1, 512, 64, 64]               0
     Double_convs-31          [-1, 512, 64, 64]               0
        MaxPool2d-32          [-1, 512, 32, 32]               0
           Conv2d-33         [-1, 1024, 30, 30]       4,719,616
      BatchNorm2d-34         [-1, 1024, 30, 30]           2,048
             ReLU-35         [-1, 1024, 30, 30]               0
           Conv2d-36         [-1, 1024, 28, 28]       9,438,208
      BatchNorm2d-37         [-1, 1024, 28, 28]           2,048
             ReLU-38         [-1, 1024, 28, 28]               0
     Double_convs-39         [-1, 1024, 28, 28]               0
 Block_DownSample-40  [[-1, 64, 568, 568], [-1, 128, 280, 280], [-1, 256, 136, 136], [-1, 512, 64, 64], [-1, 1024, 28, 28]]               0
         Upsample-41         [-1, 1024, 56, 56]               0
           Conv2d-42          [-1, 512, 56, 56]         524,800
             ReLU-43          [-1, 512, 56, 56]               0
           Conv2d-44          [-1, 512, 54, 54]       4,719,104
      BatchNorm2d-45          [-1, 512, 54, 54]           1,024
             ReLU-46          [-1, 512, 54, 54]               0
           Conv2d-47          [-1, 512, 52, 52]       2,359,808
      BatchNorm2d-48          [-1, 512, 52, 52]           1,024
             ReLU-49          [-1, 512, 52, 52]               0
     Double_convs-50          [-1, 512, 52, 52]               0
   Block_UpSample-51          [-1, 512, 52, 52]               0
         Upsample-52        [-1, 512, 104, 104]               0
           Conv2d-53        [-1, 256, 104, 104]         131,328
             ReLU-54        [-1, 256, 104, 104]               0
           Conv2d-55        [-1, 256, 102, 102]       1,179,904
      BatchNorm2d-56        [-1, 256, 102, 102]             512
             ReLU-57        [-1, 256, 102, 102]               0
           Conv2d-58        [-1, 256, 100, 100]         590,080
      BatchNorm2d-59        [-1, 256, 100, 100]             512
             ReLU-60        [-1, 256, 100, 100]               0
     Double_convs-61        [-1, 256, 100, 100]               0
   Block_UpSample-62        [-1, 256, 100, 100]               0
         Upsample-63        [-1, 256, 200, 200]               0
           Conv2d-64        [-1, 128, 200, 200]          32,896
             ReLU-65        [-1, 128, 200, 200]               0
           Conv2d-66        [-1, 128, 198, 198]         295,040
      BatchNorm2d-67        [-1, 128, 198, 198]             256
             ReLU-68        [-1, 128, 198, 198]               0
           Conv2d-69        [-1, 128, 196, 196]         147,584
      BatchNorm2d-70        [-1, 128, 196, 196]             256
             ReLU-71        [-1, 128, 196, 196]               0
     Double_convs-72        [-1, 128, 196, 196]               0
   Block_UpSample-73        [-1, 128, 196, 196]               0
         Upsample-74        [-1, 128, 392, 392]               0
           Conv2d-75         [-1, 64, 392, 392]           8,256
             ReLU-76         [-1, 64, 392, 392]               0
           Conv2d-77         [-1, 64, 390, 390]          73,792
      BatchNorm2d-78         [-1, 64, 390, 390]             128
             ReLU-79         [-1, 64, 390, 390]               0
           Conv2d-80         [-1, 64, 388, 388]          36,928
      BatchNorm2d-81         [-1, 64, 388, 388]             128
             ReLU-82         [-1, 64, 388, 388]               0
     Double_convs-83         [-1, 64, 388, 388]               0
   Block_UpSample-84         [-1, 64, 388, 388]               0
           Conv2d-85          [-1, 2, 388, 388]             130
================================================================
Total params: 28,953,474
Trainable params: 28,953,474
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 3864.11
Params size (MB): 110.45
Estimated Total Size (MB): 3975.81
----------------------------------------------------------------
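
The table above is in the format printed by the torchsummary package; assuming that is the tool used, the summary can be reproduced with a call like this (sketch):

# Sketch: reproduce the layer summary with torchsummary
# (assumption: the package is installed via `pip install torchsummary`).
from torchsummary import summary

model = U_net()
summary(model, (1, 572, 572), device='cpu')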

If you have questions, feel free to raise them, and we can learn together!

Reference blogs:

图像分割之U-Net - 知乎 (zhihu.com)

U-Net原理分析与代码解读 - 知乎 (zhihu.com)
