Building a UNet (PyTorch)

U-Net is a classic semantic segmentation network that is widely used for medical image segmentation. Its architecture can be divided into three kinds of building blocks: convolution blocks, downsampling modules and upsampling modules, as shown in the network diagram below:
[Figure: U-Net network architecture]
The network is built following this same three-part decomposition. Without further ado, here is the code:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

# Basic double-convolution block: two 3x3 convolutions, each followed by
# BatchNorm, Dropout and ReLU. Spatial size is preserved (padding=1).
class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super(conv_block, self).__init__()
        self.layer1 = nn.Sequential(
            # bias is omitted because the following BatchNorm cancels it out
            nn.Conv2d(in_c, out_c, kernel_size=(3, 3), stride=1, padding=1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_c),
            nn.Dropout(0.3),
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(out_c, out_c, kernel_size=(3, 3), stride=1, padding=1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_c),
            nn.Dropout(0.3),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

# Downsampling module: a 3x3 convolution with stride 2 halves the spatial
# resolution (used here instead of max-pooling) while keeping the channel count.
class Downsample(nn.Module):
    def __init__(self, channel):
        super(Downsample, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=(3, 3), stride=2, padding=1, bias=False),
            nn.BatchNorm2d(channel),
            nn.ReLU()
        )

    def forward(self, x):
        return self.layer(x)


# Upsampling module: nearest-neighbour interpolation doubles the resolution,
# a 1x1 convolution halves the channels, and the result is concatenated with
# the corresponding encoder feature map (the skip connection).
class Upsample(nn.Module):
    def __init__(self, channel):
        super(Upsample, self).__init__()
        self.conv1 = nn.Conv2d(channel, channel // 2, kernel_size=(1, 1), stride=1)

    def forward(self, x, featuremap):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = torch.cat((x, featuremap), dim=1)
        return x

class UNET(nn.Module):
    # in_channel: number of input image channels (3 for RGB)
    # out_channel: base feature width (64 in the original U-Net)
    def __init__(self, in_channel, out_channel):
        super(UNET, self).__init__()
        # Encoder (contracting path)
        self.layer1 = conv_block(in_channel, out_channel)
        self.layer2 = Downsample(out_channel)
        self.layer3 = conv_block(out_channel, out_channel*2)
        self.layer4 = Downsample(out_channel*2)
        self.layer5 = conv_block(out_channel*2, out_channel*4)
        self.layer6 = Downsample(out_channel*4)
        self.layer7 = conv_block(out_channel*4, out_channel*8)
        self.layer8 = Downsample(out_channel*8)
        # Bottleneck
        self.layer9 = conv_block(out_channel*8, out_channel*16)
        # Decoder (expanding path)
        self.layer10 = Upsample(out_channel*16)
        self.layer11 = conv_block(out_channel*16, out_channel*8)
        self.layer12 = Upsample(out_channel*8)
        self.layer13 = conv_block(out_channel*8, out_channel*4)
        self.layer14 = Upsample(out_channel*4)
        self.layer15 = conv_block(out_channel*4, out_channel*2)
        self.layer16 = Upsample(out_channel*2)
        self.layer17 = conv_block(out_channel*2, out_channel)
        # 1x1 convolution maps back to 3 output channels (hard-coded here),
        # followed by Sigmoid to squash each pixel into [0, 1]
        self.layer18 = nn.Conv2d(out_channel, 3, kernel_size=(1, 1), stride=1)
        self.act = nn.Sigmoid()

    def forward(self, x):
        # Encoder: keep the feature maps f1-f4 for the skip connections
        x = self.layer1(x)
        f1 = x
        x = self.layer2(x)
        x = self.layer3(x)
        f2 = x
        x = self.layer4(x)
        x = self.layer5(x)
        f3 = x
        x = self.layer6(x)
        x = self.layer7(x)
        f4 = x
        x = self.layer8(x)
        # Bottleneck
        x = self.layer9(x)
        # Decoder: upsample and concatenate with the matching encoder feature map
        x = self.layer10(x, f4)
        x = self.layer11(x)
        x = self.layer12(x, f3)
        x = self.layer13(x)
        x = self.layer14(x, f2)
        x = self.layer15(x)
        x = self.layer16(x, f1)
        x = self.layer17(x)
        x = self.layer18(x)
        return self.act(x)


if __name__ == '__main__':
    # device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(10, 3, 256, 256)   # dummy batch: 10 RGB images of 256x256
    model = UNET(3, 64)                # 3 input channels, base width 64

    out = model(x)
    print(out.size())                  # expected: torch.Size([10, 3, 256, 256])

    # Export the graph so it can be inspected with TensorBoard
    writer = SummaryWriter('log1')
    writer.add_graph(model, x)
    writer.close()
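Before visualizing the graph, it can be helpful to sanity-check the tensor shapes produced by the three building blocks. Below is a minimal sketch, assuming conv_block, Downsample and Upsample are defined as in the listing above:

```python
# Minimal shape check for the building blocks above (assumes conv_block,
# Downsample and Upsample are defined as in the listing).
import torch

block1 = conv_block(3, 64)      # two 3x3 convs: channels 3 -> 64, spatial size unchanged
down   = Downsample(64)         # stride-2 conv: halves H and W
block2 = conv_block(64, 128)
up     = Upsample(128)          # nearest upsampling + 1x1 conv + skip concatenation

x  = torch.randn(1, 3, 256, 256)
f1 = block1(x)                  # (1, 64, 256, 256)
d  = down(f1)                   # (1, 64, 128, 128)
f2 = block2(d)                  # (1, 128, 128, 128)
u  = up(f2, f1)                 # (1, 128, 256, 256): 64 upsampled channels + 64 skip channels
print(f1.shape, d.shape, f2.shape, u.shape)
```

The concatenation in Upsample is why each decoder conv_block in UNET receives twice the channels it outputs.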

Finally, after running the script we can launch TensorBoard with `tensorboard --logdir=log1` and inspect the network structure in the Graphs tab:
[Figure: the UNET graph rendered by TensorBoard]

U-Net is a deep learning model originally developed for biomedical image segmentation, but it can also be applied to image denoising. To reproduce a U-Net in PyTorch, you can follow these steps:

1. **Install the dependencies**: make sure PyTorch and its related libraries such as torchvision are installed; if needed, run `pip install torch torchvision`.

2. **Build the network**: create the core of the U-Net, consisting of an encoder (progressively lowering the resolution to extract features) and a decoder (progressively restoring the resolution to recover detail). The architecture is described in the original paper, "U-Net: Convolutional Networks for Biomedical Image Segmentation".

```python
import torch.nn as nn

class UNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(UNetBlock, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride=stride, padding=padding)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(out_channels, out_channels, kernel_size, stride=stride, padding=padding),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride=stride, padding=padding)
        )

    def forward(self, x):
        # Skip connections cannot be wired through the nn.Sequential container
        # used below, so this simplified block omits them; see the UNET class
        # above for a version with proper skip connections.
        x = self.encoder(x)
        x = self.decoder(x)
        return x


# Build a simplified, sequential U-Net-style model
def create_unet(input_channels, num_classes):
    unet = nn.Sequential(
        nn.Conv2d(input_channels, 64, 3, padding=1),
        nn.MaxPool2d(2, 2),
        UNetBlock(64, 128),
        nn.MaxPool2d(2, 2),
        UNetBlock(128, 256),
        nn.MaxPool2d(2, 2),
        UNetBlock(256, 512),
        nn.MaxPool2d(2, 2),
        UNetBlock(512, 1024),
        nn.Upsample(scale_factor=2),
        UNetBlock(1024, 512),
        nn.Upsample(scale_factor=2),
        UNetBlock(512, 256),
        nn.Upsample(scale_factor=2),
        UNetBlock(256, 128),
        nn.Upsample(scale_factor=2),
        nn.Conv2d(128, num_classes, 1)
    )
    return unet
```

3. **Train and apply**: prepare a dataset of noisy images paired with their clean counterparts, define a loss function (e.g. MSE or an SSIM-based loss) and an optimizer, and train. After training, run new noisy images through a forward pass to obtain the denoised results; a sketch of such a loop follows below.
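For step 3, a minimal training-loop sketch might look like the following. The DataLoader `noisy_clean_loader` is hypothetical (it is assumed to yield (noisy, clean) image pairs), and the learning rate and epoch count are placeholders:

```python
import torch
import torch.nn as nn

# Minimal denoising training loop (sketch). Assumes create_unet from above and a
# hypothetical DataLoader `noisy_clean_loader` yielding (noisy, clean) image pairs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = create_unet(input_channels=3, num_classes=3).to(device)
criterion = nn.MSELoss()                                   # pixel-wise reconstruction loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # placeholder learning rate

for epoch in range(10):                                    # placeholder epoch count
    for noisy, clean in noisy_clean_loader:
        noisy, clean = noisy.to(device), clean.to(device)
        optimizer.zero_grad()
        output = model(noisy)            # forward pass: predict the denoised image
        loss = criterion(output, clean)
        loss.backward()                  # backpropagate
        optimizer.step()                 # update weights
    print(f'epoch {epoch}: loss {loss.item():.4f}')
```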