【Python/Pytorch - 网络模型】-- 手把手搭建3D U-Net模型

在这里插入图片描述
文章目录

00 写在前面

通过3D U-Net代码学习,可以学习基于Pytorch的网络结构模块化编程,对于后续学习其他更复杂3D网络模型,有很大的帮助作用。

在01中,可以根据3D U-Net的网络结构(开头图片),进行模块化编程。包括卷积模块定义、上采样模块定义、下采样模块定义、输出卷积层定义、网络模型定义等。

在模型调试过程中,可以先通过简单测试代码,进行代码调试。

01 基于Pytorch版本的3D UNet代码

# 库函数调用
import torch
from torch import nn
import torch.nn.functional as F

import numpy as np

# from measure import Four_three


# 三维卷积块定义
class DoubleConv(nn.Module):
    """(Conv3D -> IN -> ReLU) * 2"""

    def __init__(self, in_channels, out_channels, num_groups = 8):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1,bias=True),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1,bias=True),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

# 下采样模块定义
class Down(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.MaxPool3d(2,2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.encoder(x)

# 上采样模块定义
class Up(nn.Module):

    def __init__(self, in_channels, out_channels, trilinear = True):
        super().__init__()

        if trilinear:
            self.up = nn.Upsample(scale_factor = 2)
        else:
            self.up = nn.ConvTranspose3d(in_channels // 2, in_channels // 2, kernel_size = 2, stride = 2)

        self.conv = DoubleConv(in_channels, out_channels)
        self.downc = nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride=1, padding=1, bias=True)
        self.downr = nn.ReLU(inplace=True)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2, diffZ // 2, diffZ - diffZ // 2])

        x1 = self.downr(self.downc(x1))

        x = torch.cat([x2, x1], dim = 1)
        return self.conv(x)

# 输出卷积层定义
class Out(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride=1, padding=1)

    def forward(self, x):
        return self.conv(x)

# 3D-UNet模型定义
class 3DUNET(nn.Module):
    def __init__(self, in_channels=3, out_channels=1,n_channels=64):
        super().__init__()
        self.in_channels = in_channels
        self.n_channels = n_channels

        self.conv = DoubleConv(in_channels, n_channels)
        self.enc1 = Down(n_channels, 2 * n_channels)
        self.enc2 = Down(2 * n_channels, 4 * n_channels)
        self.enc3 = Down(4 * n_channels, 8 * n_channels)
        self.enc4 = Down(8 * n_channels, 16 * n_channels)

        self.dec1 = Up(16 * n_channels, 8 * n_channels)
        self.dec2 = Up(8 * n_channels, 4 * n_channels)
        self.dec3 = Up(4 * n_channels, 2*n_channels)
        self.dec4 = Up(2 * n_channels, n_channels)
        self.out = Out(n_channels, out_channels) #(1,4,128,128,n)


    def forward(self, x):
        # print('size of x:', x.shape)
        x1 = self.conv(x)
        # print('size of x1:', x1.shape)
        x2 = self.enc1(x1)
        # print('size of x2:', x2.shape)
        x3 = self.enc2(x2)
        # print('size of x3:', x3.shape)
        x4 = self.enc3(x3)
        # print('size of x4:', x4.shape)
        x5 = self.enc4(x4)
        # print('size of x5:', x5.shape)

        mask = self.dec1(x5, x4)
        # print('size of mask:', mask.shape)
        mask = self.dec2(mask, x3)
        # print('size of mask:', mask.shape)
        mask = self.dec3(mask, x2)
        # print('size of mask:', mask.shape)
        mask = self.dec4(mask, x1)
        # print('size of mask:', mask.shape)
        mask = self.out(mask)
        # print('size of mask:', mask.shape)

        return mask

# 测试代码
if __name__ == '__main__':
	input_channels = 4
	output_channels = 1
	x = torch.ones([16, 4, 16, 16,16])
	model = 3DUNET(input_channels, output_channels)
	print('model initialization finished!')
	f = model(x)
	print(f)

02 论文下载

3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation
arXiv: 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation

  • 4
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值