Pytorch搭建U-Net网络

U-Net: Convolutional Networks for Biomedical Image Segmentation在这里插入图片描述

import torch.nn as nn
import torch
from torch import autograd
from torchsummary import summary

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=0),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=0),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        return self.conv(input)

class Unet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # 逆卷积,也可以使用上采样
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        c1 = self.conv1(x)
        crop1 = c1[:,:,88:480,88:480]
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        crop2 = c2[:,:,40:240,40:240]
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        crop3 = c3[:,:,16:120,16:120]
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        crop4 = c4[:,:,4:60,4:60]
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, crop4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, crop3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, crop2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, crop1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        out = nn.Sigmoid()(c10)
        return out


if __name__=="__main__":
    test_input=torch.rand(1, 1, 572, 572)
    model=Unet(in_ch=1, out_ch=2)
    summary(model, (1,572,572))
    ouput=model(test_input)
    print(ouput.size())
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 570, 570]             640
       BatchNorm2d-2         [-1, 64, 570, 570]             128
              ReLU-3         [-1, 64, 570, 570]               0
            Conv2d-4         [-1, 64, 568, 568]          36,928
       BatchNorm2d-5         [-1, 64, 568, 568]             128
              ReLU-6         [-1, 64, 568, 568]               0
        DoubleConv-7         [-1, 64, 568, 568]               0
         MaxPool2d-8         [-1, 64, 284, 284]               0
            Conv2d-9        [-1, 128, 282, 282]          73,856
      BatchNorm2d-10        [-1, 128, 282, 282]             256
             ReLU-11        [-1, 128, 282, 282]               0
           Conv2d-12        [-1, 128, 280, 280]         147,584
      BatchNorm2d-13        [-1, 128, 280, 280]             256
             ReLU-14        [-1, 128, 280, 280]               0
       DoubleConv-15        [-1, 128, 280, 280]               0
        MaxPool2d-16        [-1, 128, 140, 140]               0
           Conv2d-17        [-1, 256, 138, 138]         295,168
      BatchNorm2d-18        [-1, 256, 138, 138]             512
             ReLU-19        [-1, 256, 138, 138]               0
           Conv2d-20        [-1, 256, 136, 136]         590,080
      BatchNorm2d-21        [-1, 256, 136, 136]             512
             ReLU-22        [-1, 256, 136, 136]               0
       DoubleConv-23        [-1, 256, 136, 136]               0
        MaxPool2d-24          [-1, 256, 68, 68]               0
           Conv2d-25          [-1, 512, 66, 66]       1,180,160
      BatchNorm2d-26          [-1, 512, 66, 66]           1,024
             ReLU-27          [-1, 512, 66, 66]               0
           Conv2d-28          [-1, 512, 64, 64]       2,359,808
      BatchNorm2d-29          [-1, 512, 64, 64]           1,024
             ReLU-30          [-1, 512, 64, 64]               0
       DoubleConv-31          [-1, 512, 64, 64]               0
        MaxPool2d-32          [-1, 512, 32, 32]               0
           Conv2d-33         [-1, 1024, 30, 30]       4,719,616
      BatchNorm2d-34         [-1, 1024, 30, 30]           2,048
             ReLU-35         [-1, 1024, 30, 30]               0
           Conv2d-36         [-1, 1024, 28, 28]       9,438,208
      BatchNorm2d-37         [-1, 1024, 28, 28]           2,048
             ReLU-38         [-1, 1024, 28, 28]               0
       DoubleConv-39         [-1, 1024, 28, 28]               0
  ConvTranspose2d-40          [-1, 512, 56, 56]       2,097,664
           Conv2d-41          [-1, 512, 54, 54]       4,719,104
      BatchNorm2d-42          [-1, 512, 54, 54]           1,024
             ReLU-43          [-1, 512, 54, 54]               0
           Conv2d-44          [-1, 512, 52, 52]       2,359,808
      BatchNorm2d-45          [-1, 512, 52, 52]           1,024
             ReLU-46          [-1, 512, 52, 52]               0
       DoubleConv-47          [-1, 512, 52, 52]               0
  ConvTranspose2d-48        [-1, 256, 104, 104]         524,544
           Conv2d-49        [-1, 256, 102, 102]       1,179,904
      BatchNorm2d-50        [-1, 256, 102, 102]             512
             ReLU-51        [-1, 256, 102, 102]               0
           Conv2d-52        [-1, 256, 100, 100]         590,080
      BatchNorm2d-53        [-1, 256, 100, 100]             512
             ReLU-54        [-1, 256, 100, 100]               0
       DoubleConv-55        [-1, 256, 100, 100]               0
  ConvTranspose2d-56        [-1, 128, 200, 200]         131,200
           Conv2d-57        [-1, 128, 198, 198]         295,040
      BatchNorm2d-58        [-1, 128, 198, 198]             256
             ReLU-59        [-1, 128, 198, 198]               0
           Conv2d-60        [-1, 128, 196, 196]         147,584
      BatchNorm2d-61        [-1, 128, 196, 196]             256
             ReLU-62        [-1, 128, 196, 196]               0
       DoubleConv-63        [-1, 128, 196, 196]               0
  ConvTranspose2d-64         [-1, 64, 392, 392]          32,832
           Conv2d-65         [-1, 64, 390, 390]          73,792
      BatchNorm2d-66         [-1, 64, 390, 390]             128
             ReLU-67         [-1, 64, 390, 390]               0
           Conv2d-68         [-1, 64, 388, 388]          36,928
      BatchNorm2d-69         [-1, 64, 388, 388]             128
             ReLU-70         [-1, 64, 388, 388]               0
       DoubleConv-71         [-1, 64, 388, 388]               0
           Conv2d-72          [-1, 2, 388, 388]             130
================================================================
Total params: 31,042,434
Trainable params: 31,042,434
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 3280.59
Params size (MB): 118.42
Estimated Total Size (MB): 3400.26
----------------------------------------------------------------
torch.Size([1, 2, 388, 388])
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: 我可以给你一些关于使用PyTorch搭建HR-Net的建议:1. 先了解HR-Net的基本架构,然后熟悉PyTorch提供的相关API;2. 使用PyTorch搭建HR-Net的基本结构,并设置相应的参数;3. 调整参数,以获得最佳模型;4. 在测试集上进行验证,以确保模型具有良好的性能。 ### 回答2: HR-Net是一种用于人体关键点检测的神经网络模型,我们可以使用PyTorch搭建HR-Net。在搭建HR-Net之前,我们需要进行一些准备工作。 首先,我们需要安装PyTorch库。可以通过以下命令安装PyTorch: ``` pip install torch torchvision ``` 然后,我们需要下载HR-Net的代码和预训练的权重。可以在GitHub上找到HR-Net的代码库,并下载。下载完成后,解压缩代码包。 接下来,我们可以在PyTorch中定义HR-Net的网络结构。HR-Net基于两个主要的网络模块:骨干网络和多分支特征融合模块。 在骨干网络中,我们可以选择使用一些常见的神经网络模型,如ResNet、AlexNet等。我们可以在PyTorch中创建这些骨干网络,并将其作为HR-Net的输入。 在多分支特征融合模块中,我们通过将不同尺度的特征图进行融合,来提高人体关键点检测的准确性。我们可以在PyTorch实现这个多分支特征融合模块,并将其添加到HR-Net中。 最后,我们可以加载HR-Net的预训练权重,并将其用于人体关键点检测任务。我们可以使用PyTorch的数据加载器来加载训练数据,并使用预定义的损失函数和优化器来训练模型。 使用PyTorch搭建HR-Net可以使我们更轻松地实现人体关键点检测任务,并利用PyTorch的丰富功能来优化和扩展HR-Net模型。 ### 回答3: 使用PyTorch搭建HR-Net可以通过以下步骤完成: 1. 安装PyTorch:首先要在计算机上安装PyTorch库,可以通过在终端或命令提示符中运行适用于您的系统的安装命令来完成。 2. 导入必要的库:在Python脚本中,导入PyTorch以及其他必要的库,如numpy、matplotlib等。 3. 构建HR-Net模型:HR-Net是一种深度卷积神经网络体系结构,它具有多个分支并行处理低分辨率和高分辨率特征。可以使用PyTorch的nn.Module类构建HR-Net模型,并定义需要的卷积、池化、Batch Normalization等操作层。 4. 定义前向传播函数:在HR-Net模型类中,定义一个前向传播函数,该函数定义了输入数据通过模型时的计算流程。在这个函数中,可以将输入数据传递到HR-Net的各个分支,然后将其联合起来形成最终的输出。 5. 定义损失函数和优化器:为了训练HR-Net模型,需要定义一个损失函数来度量模型的输出和真实标签之间的差距,并选择一个优化器来更新模型的参数。PyTorch提供了各种损失函数和优化器的选项,可以根据具体问题的需求选择合适的函数和优化器。 6. 训练模型:使用已定义的损失函数和优化器,在训练数据上进行模型的训练。通过将训练数据输入到HR-Net模型中,并计算其输出与真实标签之间的损失,根据这个损失来更新模型的参数。 7. 测试模型:在训练完成后,可以使用测试数据来评估模型的性能。将测试数据输入到HR-Net模型中,获取模型的预测输出,并与真实标签进行比较,可以计算一些评价指标,例如准确率、精确率、召回率等。 8. 调整模型和超参数:根据测试结果,可以对模型和超参数进行调整,以优化模型的性能。可以更改模型的结构、增加或减少训练数据,调整学习率等。 9. 保存和加载模型:在训练完成后,可以将模型保存到磁盘上,以便后续使用。同时,也可以从保存的模型文件中加载已经训练好的模型,并在新的数据上进行预测。 以上是使用PyTorch搭建HR-Net的一般步骤,具体实现过程中可以根据需要进行进一步的细化和改进。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值