U_net 网络(pytorch学习)

U_net 网络(pytorch学习)

 U_net是一个经典的图像分割网络,可以完成许多功能,在学习U_net网络后结合B站的视频尝试编写U_net代码,锻炼编程能力

一 U_net网络结构

  U_net网络的网络结构如下图所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-alsUIo1Z-1662831702017)(“C:\Users\23577\Desktop[2H[T27D8~{THTHD]]VPJD6.png”)]

网络模型代码

  步骤:

  • 先定义下采样网络既两个卷积

    import torch
    import torchvision
    from torch import nn
    
    
    class Double_conv(nn.Module):
        def __init__(self,in_channel, out_channel):
            super(Double_conv, self).__init__()
            """
            在这里使用卷积,保持图像尺寸不变,以便更好计算
            """
            self.layer = nn.Sequential(
                nn.Conv2d(in_channel,out_channel,kernel_size=3,padding=1,bias=False),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channel,out_channel,kernel_size=3,padding=1,bias=False),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(inplace=True),
            )
        def forward(self,x):
            return self.layer(x)
    urn self.layer(x)
    
    
  • 定义U_net整个网络模型

    class U_NET(nn.Module):
        def __init__(self, in_channel,out_channel,features=[64,128,256,512]):
            super(U_NET, self).__init__()
            self.DOWN = nn.ModuleList()
            self.UP = nn.ModuleList()
            self.maxpool = nn.MaxPool2d(2)
            for feature in features:
                self.DOWN.append(Double_conv(in_channel,feature))
                in_channel = feature
    
            for feature in reversed(features):
                self.UP.append(nn.ConvTranspose2d(feature*2,feature,kernel_size=2,stride=2))
                self.UP.append(Double_conv(feature*2,feature))
    
            self.botten = Double_conv(features[-1],features[-1]*2)
            self.final_conv = nn.Conv2d(features[0],out_channel,kernel_size=1,padding=0)
    
        def forward(self,x):
            skip_connect=[]
            for idx in self.DOWN:
                x = idx(x)
                skip_connect.append(x)
                x = self.maxpool(x)
            x = self.botten(x)
            skip_connect = skip_connect[::-1]
    
            for idx in range(0,len(self.UP),2):
                x = self.UP[idx](x)
                """
                为了适用,任意尺寸的图片特征融合时为保证尺寸相同Resize一下
                """
                if x.shape != skip_connect[idx // 2].shape:
                    x = torchvision.transforms.Resize( skip_connect[idx // 2].shape[2:])(x)
                x = torch.cat((x,skip_connect[idx//2]),dim=1)
                x = self.UP[idx+1](x)
            return self.final_conv(x)
    
  • 测试结果

    if __name__ == "__main__":
        x = torch.randn(1,1,161,161)
        model = U_NET(in_channel=1,out_channel=1)
        y = model(x)
        print(y.shape)
        # -> torch.Size([1, 1, 161, 161])
    
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

theshy123333

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值