U-net 代码手搓

 用于记录初次手搓U-net遇见的一些问题。

(放上最经典的U-net)

import torch
from torch import nn
from torchsummary import summary
import torch.nn.functional as F

## 本U-Net与经典的有所不同,本文的卷积层都添加的一层padding,故输出大小不发生改变。
## 需要分割几块,就在最后的1*1的卷积设置对应的个数(本文设置为2)

#2卷积
class Conv(nn.Module):        
    def __init__(self,in_channel,out_channel):
        super(Conv,self).__init__()
        self.Relu=nn.ReLU()
        self.con1=nn.Conv2d(in_channels=in_channel,out_channels=out_channel,kernel_size=3,padding=1)
        self.con2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3,padding=1)
    def forward(self,x):
        x=self.Relu(self.con1(x))
        x = self.Relu(self.con2(x))

        return x

#下采样
class Down(nn.Module):
    def __init__(self,Conv):
        super(Down,self).__init__()

        self.b1=nn.Sequential(
            Conv(1,64)
        )

        self.b2=nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            Conv(64,128)
        )

        self.b3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            Conv(128, 256),
        )

        self.b4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            Conv(256, 512),
        )

        self.b5 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            Conv(512, 1024),
        )


    def forward(self,x):
        x1= self.b1(x)
        x2=self.b2(x1)
        x3 = self.b3(x2)
        x4 = self.b4(x3)
        x5 = self.b5(x4)

        return x1,x2,x3,x4,x5
#上采样
class Up(nn.Module):
    def __init__(self,Conv,Down):
        super(Up,self).__init__()
        self.up_sample = lambda x: F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        self.b1 = Conv(1024,512)
        self.b2 = Conv(512, 256)
        self.b3 = Conv(256, 128)
        self.b4 = Conv(128, 64)
        self.con1=nn.Conv2d(in_channels=64,out_channels=2,kernel_size=1)



    def forward(self,x):
        x1, x2, x3, x4, x5=Down(Conv)(x)
        y1=self.up_sample(x5)
        y1 = nn.Conv2d(1024, 512, kernel_size=1)(y1)
        y1 = self.b1(torch.cat((x4, y1), dim=1))

        y2 = self.up_sample(y1)
        y2 = nn.Conv2d(512, 256, kernel_size=1)(y2)
        y2 = self.b2(torch.cat((x3, y2), dim=1))

        y3 = self.up_sample(y2)
        y3 = nn.Conv2d(256, 128, kernel_size=1)(y3)
        y3 = self.b3(torch.cat((x2, y3), dim=1))

        y4 = self.up_sample(y3)
        y4 = nn.Conv2d(128, 64, kernel_size=1)(y4)
        y4 = self.b4(torch.cat((x1, y4), dim=1))

        y5= self.con1(y4)

        return y5







if __name__ == "__main__":
    device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    model = Up(Conv,Down).to(device)
    print(summary(model,(1,224,224)))  #这个函数可以通过输入图片大小,展现每一层图片处理时的大小

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值