【UNet】A UNet with configurable depth and channel counts, implemented in PyTorch

Most UNet implementations I have seen hard-code the depth and the channel counts, which makes them inconvenient to modify, so I wrote a configurable version. It is implemented in PyTorch, and both the encoder and the decoder are built from residual blocks.
Changing num_features changes the depth and channel counts of the network, e.g. num_features=[48, 96, 192, 384, …]. The list can be any length and the first number can be anything, but every subsequent number must be twice the previous one (the skip-connection concatenation in the decoder relies on this doubling; a small helper that generates such a list is sketched below).
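
Concretely, the doubling rule just means each entry is a base width multiplied by successive powers of two. The helper below is only an illustration of that rule (make_num_features is a name I made up, not part of the implementation that follows):

# Illustrative helper: build a num_features list where every entry is twice the previous one.
# len(num_features) - 1 down/up-sampling stages will be created from it.
def make_num_features(base=48, depth=4):
    return [base * (2 ** i) for i in range(depth + 1)]

print(make_num_features(48, 4))  # [48, 96, 192, 384, 768]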

PyTorch code:

import torch
import torch.nn as nn


# 3x3 convolution layer
def conv3x3(in_channels, out_channels):
    return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False)

# residual block
class ResidualBlock(nn.Module):
    def __init__(self, input_dim, output_dim, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.convs = nn.Sequential(  # two conv layers
            conv3x3(input_dim, output_dim), norm_layer(output_dim), act_layer(),  # conv + norm + activation
            conv3x3(output_dim, output_dim), norm_layer(output_dim), act_layer()
        )
        self.shortcut = nn.Conv2d(input_dim, output_dim, kernel_size=1) if input_dim != output_dim \
            else nn.Identity()  # use a 1x1 conv only when the input and output channel counts differ
    
    def forward(self, x):
        return self.convs(x) + self.shortcut(x)

# downsampling block
class Down(nn.Module):
    def __init__(self, input_dim, output_dim, ratio=2, layer='conv'):
        super().__init__()
        self.res_block = ResidualBlock(input_dim, output_dim)
        if layer == 'conv':
            self.down = nn.Sequential(
                nn.Conv2d(output_dim, output_dim, kernel_size=ratio+1, stride=ratio, padding=1), # strided-conv downsampling (padding=1 matches the default ratio=2)
                nn.ReLU()
            )
        else:
            self.down = nn.AvgPool2d(kernel_size=ratio, stride=ratio) # average-pooling downsampling
    
    def forward(self, x):
        x = self.res_block(x)
        x = self.down(x)
        return x

# upsampling block
class Up(nn.Module):
    def __init__(self, input_dim, output_dim, ratio=2):
        super().__init__()
        self.up = nn.Upsample(scale_factor=ratio, mode='nearest')  # nearest-neighbor upsampling
        self.reduce = nn.Conv2d(input_dim, input_dim // 2, kernel_size=1) # 1x1 conv to halve the channel count
        self.res_block = ResidualBlock(input_dim, output_dim)
    
    def forward(self, x, r): # r is the skip connection from the encoder
        x = self.up(x)
        x = self.reduce(x)
        # concatenate with the skip connection
        x = torch.cat((x, r), dim=1)
        x = self.res_block(x)
        return x


# UNet with configurable depth and channel counts
class Unet(nn.Module):
    def __init__(self, input_dim, output_dim, num_features=[64, 128, 256, 512, 1024]):  # any depth and channel counts, as long as each entry is twice the previous one
        super().__init__()
        self.conv_1 = nn.Sequential(
            conv3x3(input_dim, num_features[0]), nn.BatchNorm2d(num_features[0]), nn.ReLU()
        )
        
        self.encoder_x = nn.ModuleList([  # x encoders, x = len(num_features) - 1
            Down(in_dim, out_dim) for in_dim, out_dim in zip(num_features, num_features[1:]) # built in a loop and stored in a ModuleList
        ])

        self.decoder_x = nn.ModuleList([  # x decoders, x = len(num_features) - 1
            Up(in_dim, out_dim) for in_dim, out_dim in zip(reversed(num_features), reversed(num_features[:-1]))
        ])
        
        self.conv_out = conv3x3(num_features[0], output_dim)
    
    def forward(self, x):
        x = self.conv_1(x)
        e_outputs = []  # store the encoder outputs for the skip connections
        
        # encoder / downsampling path
        for down in self.encoder_x:
            e_outputs.append(x)
            x = down(x)
        
        # decoder / upsampling path
        for up, r in zip(self.decoder_x, reversed(e_outputs)):  # the encoder outputs are consumed in reverse order
            x = up(x, r)  # r is the skip connection
        
        x = self.conv_out(x)

        if x.shape[1] == 1:  # apply sigmoid when there is a single output channel
            x = torch.sigmoid(x)
        return x

# quick test
model = Unet(3, 1)
x = torch.randn(4, 3, 64, 64)
y = model(x)
print(y.shape)
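
Since the depth is driven entirely by num_features, a shallower and narrower variant should run through the same forward pass, provided the input resolution is divisible by 2 raised to the number of stages. The names below (small_model, y_small) are just for illustration:

# shallower, narrower configuration: 3 down/up-sampling stages instead of 4
small_model = Unet(3, 1, num_features=[32, 64, 128, 256])
y_small = small_model(torch.randn(4, 3, 64, 64))
print(y_small.shape)  # expected: torch.Size([4, 1, 64, 64])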
