神经网络基础模型【2】:U-Net及其各种变体

本文介绍了U-Net网络结构,一种常用于医学图像分割的深度学习模型,其金字塔结构在低级视觉任务如超分和降噪中也发挥作用。通过PyTorch实现了一个简单的U-Net实例,并提到了后续改进版本如使用残差块以提升性能。
摘要由CSDN通过智能技术生成

最早被提出应用于医学图像分割,后扩展至通用分割,后面在low-level领域也发挥着巨大的作用。

U-Net

友情链接:U-Net: Convolutional Networks for Biomedical Image Segmentation)

网络原理

U-Net 网络结构

U-Net 的网络结构其实很简单,类似于传统图像处理中的金字塔结构。对输入进行多次的 conv+relu 特征提取,然后进行 maxpooling 下采样,扩大感受野的同时减小特征图尺寸,循环多次后得到上图中最下面一层的特征图,然后进行 upsample+conv+concat,再对上采样之后的特征图进行 conv+relu 操作,和前面一样重复多次,便得到了最后的结果。

如今的 U-Net 不光是在分割领域,在笔者所从事的 low-level 视觉中也得到了广泛的应用,比如超分、降噪等,他们的一个共同特征是,输入和输出往往是相同尺寸的,image2image 的任务。

pytorch 代码

import torch
import torch.nn as nn


def double_conv_relu(n_in, n_out):
    block = nn.Sequential(
        nn.Conv2d(n_in, n_out, 3, 1, 1),
        nn.ReLU(),
        nn.Conv2d(n_out, n_out, 3, 1, 1),
        nn.ReLU(),
    )

    return block


class UNet(nn.Module):
    def __init__(self) -> None:
        super(UNet, self).__init__()

        self.conv1 = double_conv_relu(1, 64)
        self.down1 = nn.MaxPool2d(2, 2)

        self.conv2 = double_conv_relu(64, 128)
        self.down2 = nn.MaxPool2d(2, 2)

        self.conv3 = double_conv_relu(128, 256)
        self.down3 = nn.MaxPool2d(2, 2)

        self.conv4 = double_conv_relu(256, 512)
        self.down4 = nn.MaxPool2d(2, 2)

        self.conv5 = double_conv_relu(512, 1024)

        self.up1 = nn.ConvTranspose2d(1024, 512, 4, 2, 1)
        self.conv6 = double_conv_relu(1024, 512)

        self.up2 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
        self.conv7 = double_conv_relu(512, 256)

        self.up3 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.conv8 = double_conv_relu(256, 128)

        self.up4 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.conv9 = double_conv_relu(128, 64)

        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

    def forward(self, x):
        feat1 = self.conv1(x)
        feat2 = self.conv2(self.down1(feat1))
        feat3 = self.conv3(self.down2(feat2))
        feat4 = self.conv4(self.down3(feat3))
        feat5 = self.conv5(self.down4(feat4))
        feat6 = self.conv6(torch.cat((feat4, self.up1(feat5)), dim=1))
        feat7 = self.conv7(torch.cat((feat3, self.up2(feat6)), dim=1))
        feat8 = self.conv8(torch.cat((feat2, self.up3(feat7)), dim=1))
        feat9 = self.conv9(torch.cat((feat1, self.up4(feat8)), dim=1))
        out = self.conv_last(feat9)

        return out


if __name__ == "__main__":
    x = torch.rand(1, 1, 256, 256)

    net = UNet(3, 64, 3)

    y = net(x)
    print(y.shape)

自己写的,比较粗糙,主要看一下网络结构。因为 U-Net 提出是在 2015 年,很早了,后续在应用的时候会把原本代码中的 conv+relu 的结构换成残差块的结构,可以取得更好的效果。

UNet++

友情链接:UNet++: A Nested U-Net Architecture for Medical Image Segmentation

nnUNet

友情链接:nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值