UNet++网络复现,包括深度监督

注意

代码复现的时候,遵循从左下到右上的顺序,这样思路就会更清楚。UNet++原论文的图解给的详细信息不多,建议先将UNet复现之后,UNet++就很容易上手了。

代码:

# coding:utf8
from modulefinder import Module

import torch
from torch import nn


class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, pre_BachNorm=False):
        super(Conv, self).__init__()
        if pre_BachNorm:
            self.conv = nn.Sequential(
                nn.BatchNorm2d(in_channels),
                nn.SiLU(),
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(out_channels),
                nn.SiLU(),
            )

    def forward(self, x):
        return self.conv(x)


# 下采样
class Down_Conv(nn.Module):
    def __init__(self, channels):
        super(Down_Conv, self).__init__()
        # self.down_conv = nn.Sequential(
        #     # 原始的只有一个Maxpool,可以在maxpool后加一个卷积层,进行特征融合
        #     nn.MaxPool2d(2),
        #     nn.Conv2d(channels),
        #     nn.BatchNorm2d(channels),
        #     nn.SiLU(inplace=True)
        # )

        # 为了融合更多的信息,我觉得还是卷积比较好
        self.down_conv = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(channels),
            nn.SiLU(),
        )

    def forward(self, x):
        return self.down_conv(x)


# 上采样
class Up_Conv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Up_Conv, self).__init__()
        self.up = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1),
            nn.BatchNorm2d(out_channels),
            nn.SiLU()
        )

    def forward(self, x):
        return self.up(x)


class UnetPulsPuls(nn.Module):
    def __init__(self, supervised):
        super(UnetPulsPuls, self).__init__()

        self.supervised = supervised

        self.stage1 = Conv(3, 64, pre_BachNorm=True)
        self.stage1_down = Down_Conv(64)
        self.stage2 = Conv(64, 128, True)
        self.stage2_up = Up_Conv(128, 64)
        self.stage2_down = Down_Conv(128)
        self.stage3 = Conv(128, 256, True)
        self.stage3_up = Up_Conv(256, 128)
        self.stage3_down = Down_Conv(256)
        self.stage4 = Conv(256, 512, True)
        self.stage4_up = Up_Conv(512, 256)
        self.stage4_down = Down_Conv(512)
        self.stage5 = Conv(512, 1024, True)
        self.stage5_up = Up_Conv(1024, 512)

        self.x_0_1 = Conv(64 * 2, 64)
        self.x_0_2 = Conv(64 * 3, 64)
        self.x_0_3 = Conv(64 * 4, 64)
        self.x_0_4 = Conv(64 * 5, 64)

        self.x_1_1 = Conv(128 * 2, 128)
        self.x_1_1_up = Up_Conv(128, 64)
        self.x_1_2 = Conv(128 * 3, 128)
        self.x_1_2_up = Up_Conv(128, 64)
        self.x_1_3 = Conv(128 * 4, 128)
        self.x_1_3_up = Up_Conv(128, 64)

        self.x_2_1 = Conv(256 * 2, 256)
        self.x_2_1_up = Up_Conv(256, 128)
        self.x_2_2 = Conv(256 * 3, 256)
        self.x_2_2_up = Up_Conv(256, 128)

        self.x_3_1 = Conv(512 * 2, 512)
        self.x_3_1_up = Up_Conv(512, 256)

        self.end = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x_0_0 = self.stage1(x)
        x_1_0 = self.stage2(self.stage1_down(x_0_0))
        x_2_0 = self.stage3(self.stage2_down(x_1_0))
        x_3_0 = self.stage4(self.stage3_down(x_2_0))
        x_4_0 = self.stage5(self.stage4_down(x_3_0))


        x_0_1 = self.x_0_1(torch.cat([x_0_0, self.stage2_up(x_1_0)], dim=1))
        x_1_1 = self.x_1_1(torch.cat([x_1_0, self.stage3_up(x_2_0)], dim=1))
        x_2_1 = self.x_2_1(torch.cat([x_2_0, self.stage4_up(x_3_0)], dim=1))
        x_3_1 = self.x_3_1(torch.cat([x_3_0, self.stage5_up(x_4_0)], dim=1))

        x_0_2 = self.x_0_2(torch.cat([x_0_0, x_0_1, self.x_1_1_up(x_1_1)], dim=1))
        x_1_2 = self.x_1_2(torch.cat([x_1_0, x_1_1, self.x_2_1_up(x_2_1)], dim=1))
        x_2_2 = self.x_2_2(torch.cat([x_2_0, x_2_1, self.x_3_1_up(x_3_1)], dim=1))

        x_0_3 = self.x_0_3(torch.cat([x_0_0, x_0_1, x_0_2, self.x_1_2_up(x_1_2)], dim=1))
        x_1_3 = self.x_1_3(torch.cat([x_1_0, x_1_1, x_1_2, self.x_2_2_up(x_2_2)], dim=1))

        x_0_4 = self.x_0_4(torch.cat([x_0_0, x_0_1, x_0_2, x_0_3, self.x_1_3_up(x_1_3)], dim=1))

        if self.supervised:
            return self.end(x_0_1), self.end(x_0_2), self.end(x_0_3), self.end(x_0_4)
        else:
            return self.end(x_0_4)


if __name__ == '__main__':
    xx = torch.randn((1, 3, 640, 640))
    mask = torch.rand(1, 3, 640, 640)

    model = UnetPulsPuls(supervised=True)

    # for name, layer in model.named_children():
    #     xx = layer(xx)
    #     print(name, xx.shape)

    x_0_1, x_0_2, x_0_3, x_0_4 = model(xx)
    l1 = mask - x_0_1
    l2 = mask - x_0_2
    l3 = mask - x_0_3
    l4 = mask - x_0_4
    l = l1 + l2 + l3 + l4


### UNet+++ 模型代码实现与复现教程 #### GitHub资源链接 对于希望获取UNet+++模型的代码实现或复现教程的人而言,有多个GitHub仓库提供了详细的资料和支持。一个官方推荐的PyTorch实现可以在bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets找到[^1]。此项目不仅包含了基础的U-Net架构,还扩展到了更复杂的变体如RCNN-U-net, Attention U-net等。 另一个重要的资源来自ZJUGiveLab/UNet-Version,该库专注于提供多种版本的U-Net及其衍生模型(包括UNet+++, 注意力机制增强版)的具体实现方法[^5]。这些资源非常适合那些想要深入了解并尝试不同改进方案的研究人员和技术爱好者们。 #### 实战指南与解释文档 为了帮助理解如何实际操作以及背后的设计理念,在CSDN博客上有一系列文章深入浅出地介绍了UNet+++的关键概念和重点代码片段。通过阅读这些材料,读者可以获得关于为什么某些设计决策被采纳的第一手见解,并学习到最佳实践技巧来优化自己的项目开发过程。 #### TensorFlow/Keras实现 除了上述提到的主要针对PyTorch框架的内容外,MrGiovanni/UNetPlusPlus则是一个专门为TensorFlow用户提供服务的开源项目[^2]。它实现了原始论文中的UNet++算法,并附带了详尽的例子说明怎样训练模型处理医学图像数据集等问题。虽然这不是严格意义上的UNet+++,但对于熟悉Keras/TensorFlow环境的人来说仍然是非常有价值的参考资料之一。 ```python import torch.nn as nn class UNetBlock(nn.Module): """Basic block used within the UNet architecture.""" def __init__(self, in_channels, out_channels): super().__init__() self.conv = nn.Sequential( nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3), nn.ReLU(inplace=True), nn.BatchNorm2d(out_channels)) def forward(self, x): return self.conv(x) def build_unetppp_model(): """ Constructs a simplified version of the UNet+++ model. Note that this is not an exact implementation but serves to illustrate key components. For full details please refer to official repositories or publications. """ pass # Placeholder function body; actual construction would involve stacking multiple blocks with skip connections etc. if __name__ == "__main__": # Example usage when running script directly (not recommended for production code) from torchvision import models net = build_unetppp_model() print(net) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Jay_Mapp

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值