自己编写的 UNet、UNet++(UnetPP)和 UNet 3+(Unet3P)网络（PyTorch 实现）

如有不足，欢迎指正。

UNet 代码如下：

from torch import nn
import torch
import torch.nn.functional as F

def contracting_block(in_channels, out_channels):
    """Build a double-convolution stage.

    Two identical sub-stages are stacked, each being
    ``Conv2d(3x3, padding=1) -> ReLU -> BatchNorm2d``. The padding keeps H and W
    unchanged; only the channel count goes from ``in_channels`` to ``out_channels``.

    Returns:
        nn.Sequential with six modules (indices 0..5).
    """
    layers = []
    channels_in = in_channels
    for _ in range(2):
        layers.extend([
            nn.Conv2d(in_channels=channels_in, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        ])
        # The second conv consumes the first conv's output.
        channels_in = out_channels
    return nn.Sequential(*layers)


class expansive_block(nn.Module):
    """Decoder stage: up-sample the deeper map, align it to the skip, concatenate, double conv.

    Args:
        in_channels: channels of the deeper input ``d``. The transposed conv halves
            them, and the concatenation ``[e, up(d)]`` must add up to ``in_channels``
            again because that is what the first 3x3 conv expects.
        mid_channels: channels after the first 3x3 conv.
        out_channels: channels after the second 3x3 conv.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        # kernel 3, stride 2, padding 1, output_padding 1 -> exactly 2x spatial size.
        self.up = nn.ConvTranspose2d(
            in_channels, in_channels // 2,
            kernel_size=3, stride=2, padding=1, output_padding=1,
        )
        first = [
            nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channels),
        ]
        second = [
            nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        ]
        self.block = nn.Sequential(*first, *second)

    def forward(self, e, d):
        """Up-sample ``d``, pad it to ``e``'s H/W, concat ``[e, d]`` on channels and convolve."""
        grown = self.up(d)
        gap_h = e.size(2) - grown.size(2)
        gap_w = e.size(3) - grown.size(3)
        # Split the difference as evenly as possible (extra pixel goes right/bottom).
        top, left = gap_h // 2, gap_w // 2
        grown = F.pad(grown, [left, gap_w - left, top, gap_h - top])
        merged = torch.cat([e, grown], dim=1)
        return self.block(merged)



def final_block(in_channels, out_channels):
    """Build the output head: ``Conv2d(3x3, padding=1) -> ReLU -> BatchNorm2d``.

    Spatial size is preserved; channels go from ``in_channels`` to ``out_channels``.
    NOTE(review): applying ReLU + BatchNorm to the final prediction is unusual
    (a plain conv producing logits is more common) -- confirm this is intended.
    """
    head_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
    return nn.Sequential(head_conv, nn.ReLU(), nn.BatchNorm2d(out_channels))


class Unet(nn.Module):
    """U-Net encoder/decoder with skip connections.

    Five contracting stages (64 -> 1024 channels), each followed by 2x max-pooling,
    a 2048-channel bottleneck, five expansive stages back to 64 channels, and a
    final conv head.

    Args:
        in_channels: channels of the input image.
        out_channels: channels produced by the final head.
    """

    def __init__(self, in_channels, out_channels):
        super(Unet, self).__init__()

        self.conv_encode1 = contracting_block(in_channels=in_channels, out_channels=64)
        self.conv_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode2 = contracting_block(in_channels=64, out_channels=128)
        self.conv_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode3 = contracting_block(in_channels=128, out_channels=256)
        self.conv_pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode4 = contracting_block(in_channels=256, out_channels=512)
        self.conv_pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode5 = contracting_block(in_channels=512, out_channels=1024)
        self.conv_pool5 = nn.MaxPool2d(kernel_size=2, stride=2)

        # padding=1 keeps the spatial size, like every other 3x3 conv in the network.
        # Without it each conv shrinks the map by 2 pixels. Small inputs then fail
        # (e.g. 64x64 -> 2x2 after five poolings, smaller than the 3x3 kernel), and
        # larger inputs force the first decoder stage to zero-pad its up-sampled map.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=1024, out_channels=2048, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(2048),

            nn.Conv2d(kernel_size=3, in_channels=2048, out_channels=2048, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(2048),
        )

        # Each decoder stage halves the deeper map's channels when up-sampling and
        # concatenates it with the same-depth encoder output.
        self.conv_decode5 = expansive_block(2048, 1024, 1024)
        self.conv_decode4 = expansive_block(1024, 512, 512)
        self.conv_decode3 = expansive_block(512, 256, 256)
        self.conv_decode2 = expansive_block(256, 128, 128)
        self.conv_decode1 = expansive_block(128, 64, 64)

        self.final_layer = final_block(64, out_channels)

    def forward(self, x):
        """Run the network; the output has the same H and W as ``x``."""
        # Encoder: keep every stage's output for the skip connections.
        encode_block1 = self.conv_encode1(x)
        pool1 = self.conv_pool1(encode_block1)
        encode_block2 = self.conv_encode2(pool1)
        pool2 = self.conv_pool2(encode_block2)
        encode_block3 = self.conv_encode3(pool2)
        pool3 = self.conv_pool3(encode_block3)
        encode_block4 = self.conv_encode4(pool3)
        pool4 = self.conv_pool4(encode_block4)
        encode_block5 = self.conv_encode5(pool4)
        pool5 = self.conv_pool5(encode_block5)

        bridge = self.bottleneck(pool5)

        # Decoder: expansive_block pads/crops the up-sampled map to the skip's size.
        decoder5 = self.conv_decode5(encode_block5, bridge)
        decoder4 = self.conv_decode4(encode_block4, decoder5)
        decoder3 = self.conv_decode3(encode_block3, decoder4)
        decoder2 = self.conv_decode2(encode_block2, decoder3)
        decoder1 = self.conv_decode1(encode_block1, decoder2)

        final_layer = self.final_layer(decoder1)
        return final_layer

UNet++(代码中的类名为 Unet2P_)如下：

from torch import nn
import torch
import torch.nn.functional as F


def contracting_block(in_channels, out_channels):
    """Build a double-convolution stage.

    Two identical sub-stages are stacked, each being
    ``Conv2d(3x3, padding=1) -> ReLU -> BatchNorm2d``. The padding keeps H and W
    unchanged; only the channel count goes from ``in_channels`` to ``out_channels``.

    Returns:
        nn.Sequential with six modules (indices 0..5).
    """
    layers = []
    channels_in = in_channels
    for _ in range(2):
        layers.extend([
            nn.Conv2d(in_channels=channels_in, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        ])
        # The second conv consumes the first conv's output.
        channels_in = out_channels
    return nn.Sequential(*layers)


def up(in_channels):
    """Return a learnable 2x up-sampling layer that halves the channel count.

    With kernel_size=3, stride=2, padding=1 and output_padding=1 the output is
    exactly twice the input in H and W.
    """
    halved = in_channels // 2
    return nn.ConvTranspose2d(in_channels, halved, kernel_size=3, stride=2,
                              padding=1, output_padding=1)


def final_block(in_channels, out_channels):
    """Build the output head: ``Conv2d(3x3, padding=1) -> ReLU -> BatchNorm2d``.

    Spatial size is preserved; channels go from ``in_channels`` to ``out_channels``.
    NOTE(review): applying ReLU + BatchNorm to the final prediction is unusual
    (a plain conv producing logits is more common) -- confirm this is intended.
    """
    head_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
    return nn.Sequential(head_conv, nn.ReLU(), nn.BatchNorm2d(out_channels))


class Unet2P_(nn.Module):
    """UNet++ (nested U-Net) with optional deep supervision.

    Node ``xi_j`` sits at depth ``i`` and is the j-th node at that depth. The
    ``xi_0`` nodes form the encoder; every ``xi_j`` with j > 0 fuses all earlier
    nodes at depth ``i`` with the 2x up-sampled ``x(i+1)_(j-1)``.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of each prediction.
        deep_supervision: if True, ``forward`` returns a list of four predictions
            computed from x0_1 .. x0_4; otherwise a single prediction from x0_4.
    """

    def __init__(self, in_channels, out_channels, deep_supervision=False):
        super(Unet2P_, self).__init__()

        self.deep_supervision = deep_supervision
        self.pool = nn.MaxPool2d(2, 2)
        # One learnable 2x up-sampler per source depth (channels halved), shared by
        # every node that up-samples from that depth. Only depths 1-4 are ever
        # up-sampled. Extra up-samplers would hold parameters that never receive
        # gradients: wasted memory, and errors under DistributedDataParallel.
        self.up1 = up(128)
        self.up2 = up(256)
        self.up3 = up(512)
        self.up4 = up(1024)

        # Encoder.
        self.conv1 = contracting_block(in_channels, 64)
        self.conv2 = contracting_block(64, 128)
        self.conv3 = contracting_block(128, 256)
        self.conv4 = contracting_block(256, 512)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(1024),

            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
        )

        self.final_layer = final_block(64, out_channels)

        # Nested nodes: input width = (j previous nodes + 1 up-sampled) * depth width.
        self.con01 = contracting_block(128, 64)
        self.con11 = contracting_block(256, 128)
        self.con21 = contracting_block(512, 256)
        self.con31 = contracting_block(1024, 512)

        self.con02 = contracting_block(192, 64)
        self.con12 = contracting_block(128 * 3, 128)
        self.con22 = contracting_block(256 * 3, 256)

        self.con03 = contracting_block(64 * 4, 64)
        self.con13 = contracting_block(128 * 4, 128)

        self.con04 = contracting_block(64 * 5, 64)

        if self.deep_supervision:
            self.final1 = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)
            self.final2 = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)
            self.final3 = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)
            self.final4 = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)

    @staticmethod
    def _match_size(ref, t):
        """Zero-pad ``t`` (a negative difference would crop) so its H and W equal ``ref``'s.

        Needed when the input size is not divisible by 16: max-pooling floors odd
        sizes, so a 2x up-sampled map can be one pixel smaller than its skip.
        """
        dy = ref.size(2) - t.size(2)
        dx = ref.size(3) - t.size(3)
        if dy == 0 and dx == 0:
            return t
        return F.pad(t, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])

    def _fuse_up(self, skips, up_module, deeper):
        """Concatenate ``skips`` with ``up_module(deeper)`` (size-matched) along channels."""
        upsampled = self._match_size(skips[0], up_module(deeper))
        return torch.cat([*skips, upsampled], dim=1)

    def forward(self, x):
        # Encoder column (j = 0).
        x0_0 = self.conv1(x)
        x1_0 = self.conv2(self.pool(x0_0))
        x2_0 = self.conv3(self.pool(x1_0))
        x3_0 = self.conv4(self.pool(x2_0))
        x4_0 = self.bottleneck(self.pool(x3_0))

        # Nested decoder nodes, column by column.
        x0_1 = self.con01(self._fuse_up([x0_0], self.up1, x1_0))
        x1_1 = self.con11(self._fuse_up([x1_0], self.up2, x2_0))
        x2_1 = self.con21(self._fuse_up([x2_0], self.up3, x3_0))
        x3_1 = self.con31(self._fuse_up([x3_0], self.up4, x4_0))

        x0_2 = self.con02(self._fuse_up([x0_0, x0_1], self.up1, x1_1))
        x1_2 = self.con12(self._fuse_up([x1_0, x1_1], self.up2, x2_1))
        x2_2 = self.con22(self._fuse_up([x2_0, x2_1], self.up3, x3_1))

        x0_3 = self.con03(self._fuse_up([x0_0, x0_1, x0_2], self.up1, x1_2))
        x1_3 = self.con13(self._fuse_up([x1_0, x1_1, x1_2], self.up2, x2_2))

        x0_4 = self.con04(self._fuse_up([x0_0, x0_1, x0_2, x0_3], self.up1, x1_3))

        if self.deep_supervision:
            return [
                self.final1(x0_1),
                self.final2(x0_2),
                self.final3(x0_3),
                self.final4(x0_4),
            ]
        return self.final_layer(x0_4)

UNet 3+(Unet3P)代码如下：

from torch import nn
import torch
import torch.nn.functional as F

def contracting_block(in_channels, out_channels):
    """Build a double-convolution stage.

    Two identical sub-stages are stacked, each being
    ``Conv2d(3x3, padding=1) -> ReLU -> BatchNorm2d``. The padding keeps H and W
    unchanged; only the channel count goes from ``in_channels`` to ``out_channels``.

    Returns:
        nn.Sequential with six modules (indices 0..5).
    """
    layers = []
    channels_in = in_channels
    for _ in range(2):
        layers.extend([
            nn.Conv2d(in_channels=channels_in, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        ])
        # The second conv consumes the first conv's output.
        channels_in = out_channels
    return nn.Sequential(*layers)

def up(in_channels):
    """Return a learnable 2x up-sampling layer that halves the channel count.

    With kernel_size=3, stride=2, padding=1 and output_padding=1 the output is
    exactly twice the input in H and W.
    """
    halved = in_channels // 2
    return nn.ConvTranspose2d(in_channels, halved, kernel_size=3, stride=2,
                              padding=1, output_padding=1)

def up1(in_channels):
    """Return a learnable 4x up-sampling layer mapping in_channels -> in_channels // 4.

    A 3x3 conv (padding=1) changes the channel count at the low resolution, then
    bilinear interpolation enlarges H and W by exactly 4. A single
    ConvTranspose2d(kernel_size=3, stride=4) is not used on purpose: with a kernel
    smaller than the stride, roughly one row and one column in every four receive
    no input contribution at all (only the bias). This version has the same number
    of parameters and the same output shape, and every output pixel is covered.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, in_channels // 4, kernel_size=3, padding=1),
        nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
    )


def up2(in_channels):
    """Return a learnable 8x up-sampling layer mapping in_channels -> in_channels // 8.

    A 3x3 conv (padding=1) changes the channel count at the low resolution, then
    bilinear interpolation enlarges H and W by exactly 8. A single
    ConvTranspose2d(kernel_size=3, stride=8) is not used on purpose: with a kernel
    smaller than the stride, 5 of every 8 output rows/columns receive no input
    contribution (only the bias). This version has the same number of parameters
    and the same output shape, and every output pixel is covered.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, in_channels // 8, kernel_size=3, padding=1),
        nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False),
    )

def up3(in_channels):
    """Return a learnable 16x up-sampling layer mapping in_channels -> in_channels // 16.

    A 3x3 conv (padding=1) changes the channel count at the low resolution, then
    bilinear interpolation enlarges H and W by exactly 16. A single
    ConvTranspose2d(kernel_size=3, stride=16) is not used on purpose: with a kernel
    smaller than the stride, 13 of every 16 output rows/columns receive no input
    contribution (only the bias). This version has the same number of parameters
    and the same output shape, and every output pixel is covered.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, in_channels // 16, kernel_size=3, padding=1),
        nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False),
    )


def final_block(in_channels, out_channels):
    """Build the output head: ``Conv2d(3x3, padding=1) -> ReLU -> BatchNorm2d``.

    Spatial size is preserved; channels go from ``in_channels`` to ``out_channels``.
    NOTE(review): applying ReLU + BatchNorm to the final prediction is unusual
    (a plain conv producing logits is more common) -- confirm this is intended.
    """
    head_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
    return nn.Sequential(head_conv, nn.ReLU(), nn.BatchNorm2d(out_channels))

class Unet3P(nn.Module):
    """UNet 3+ with full-scale skip connections.

    Each decoder stage concatenates five equally wide feature maps:
    - up-sampled outputs of all deeper stages;
    - the same-depth encoder output;
    - max-pooled outputs of all shallower encoder stages, each passed through its
      own double-conv block.
    A double-conv block then fuses the concatenation.

    Args:
        in_channel: channels of the input image.
        out_channel: channels produced by the final head.
    """

    def __init__(self, in_channel, out_channel):
        super(Unet3P, self).__init__()
        # Encoder: one double-conv block per depth (64 -> 1024 channels).
        self.conv1 = contracting_block(in_channel, 64)
        self.conv2 = contracting_block(64, 128)
        self.conv3 = contracting_block(128, 256)
        self.conv4 = contracting_block(256, 512)
        self.conv5 = contracting_block(512, 1024)

        # Max-pooling by 2, 4 and 8 (encoder and down-sampled skip paths).
        self.down = nn.MaxPool2d(2, 2)
        self.down1 = nn.MaxPool2d(4, 4)
        self.down2 = nn.MaxPool2d(8, 8)

        # Learnable 2x up-sampling (channels halved).
        self.up2 = up(128)
        self.up3 = up(256)
        self.up4 = up(512)
        self.up5 = up(1024)

        # Decoder fusion blocks: five equally wide inputs -> one output per depth.
        self.c4 = contracting_block(512 * 5, 512)
        self.c3 = contracting_block(256 * 5, 256)
        self.c2 = contracting_block(128 * 5, 128)
        self.c1 = contracting_block(64 * 5, 64)

        # Skip paths from shallower encoder stages (applied after max-pooling).
        self.con1 = contracting_block(64, 512)    # x1 -> depth of x4
        self.con2 = contracting_block(128, 512)   # x2 -> depth of x4
        self.con3 = contracting_block(256, 512)   # x3 -> depth of x4
        self.con11 = contracting_block(64, 256)   # x1 -> depth of x3
        # Dedicated blocks, so these skip paths do not reuse the encoder blocks
        # conv3/conv2 (which would share their weights and BatchNorm statistics).
        self.con21 = contracting_block(128, 256)  # x2 -> depth of x3
        self.con111 = contracting_block(64, 128)  # x1 -> depth of x2

        # Learnable 4x / 8x / 16x up-sampling from deeper stages.
        self.up55 = up1(1024)
        self.up44 = up1(512)
        self.up33 = up1(256)

        self.up555 = up2(1024)
        self.up444 = up2(512)

        self.up5555 = up3(1024)

        self.final = final_block(64, out_channel)

    @staticmethod
    def _match_size(ref, t):
        """Zero-pad ``t`` (a negative difference would crop) so its H and W equal ``ref``'s.

        Needed when the input size is not divisible by 16: max-pooling floors odd
        sizes, so an up-sampled map can be a few pixels smaller than its target.
        """
        dy = ref.size(2) - t.size(2)
        dx = ref.size(3) - t.size(3)
        if dy == 0 and dx == 0:
            return t
        return F.pad(t, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])

    def forward(self, x):
        # Encoder; e.g. for a (1, 3, 512, 512) input, x1..x5 are 512, 256, 128, 64
        # and 32 pixels wide with 64..1024 channels.
        x1 = self.conv1(x)
        x2 = self.conv2(self.down(x1))
        x3 = self.conv3(self.down(x2))
        x4 = self.conv4(self.down(x3))
        x5 = self.conv5(self.down(x4))

        fit = self._match_size

        # Decoder stage at x4's resolution.
        d4 = self.c4(torch.cat([
            fit(x4, self.up5(x5)),
            x4,
            self.con3(self.down(x3)),
            self.con2(self.down1(x2)),
            self.con1(self.down2(x1)),
        ], dim=1))

        # Decoder stage at x3's resolution.
        d3 = self.c3(torch.cat([
            fit(x3, self.up55(x5)),
            fit(x3, self.up4(d4)),
            x3,
            self.con21(self.down(x2)),
            self.con11(self.down1(x1)),
        ], dim=1))

        # Decoder stage at x2's resolution.
        d2 = self.c2(torch.cat([
            fit(x2, self.up555(x5)),
            fit(x2, self.up44(d4)),
            fit(x2, self.up3(d3)),
            x2,
            self.con111(self.down(x1)),
        ], dim=1))

        # Decoder stage at x1's (input) resolution.
        d1 = self.c1(torch.cat([
            fit(x1, self.up5555(x5)),
            fit(x1, self.up444(d4)),
            fit(x1, self.up33(d3)),
            fit(x1, self.up2(d2)),
            x1,
        ], dim=1))

        return self.final(d1)

if __name__ == '__main__':
    # Smoke test: push one random 3-channel 512x512 image through Unet3P
    # (2 output channels) and print the shape of the result.
    print('*--' * 5)
    sample = torch.randn(1, 3, 512, 512)
    model = Unet3P(3, 2)
    prediction = model(sample)
    print(prediction.shape)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值