看了很多Unet的代码,基本上都是固定深度和通道数的,不方便修改,自己写了一个可以修改的版本。pytorch实现,encoder和decoder用的都是残差块。
修改 num_features 就可以改变网络的深度和通道数,如 num_features=[48, 96, 192, 384, …]。这个列表可以任意长,第一个数字也可以任意取,但后面的每个数字都必须是前一个的二倍(跨层拼接依赖这个倍数关系)。
pytorch代码:
import torch
import torch.nn as nn
# 3x3 convolution helper
def conv3x3(in_channels, out_channels):
    """Return a 3x3 conv with padding 1 (keeps spatial size), bias disabled
    since a normalization layer follows it everywhere in this file."""
    return nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        padding=1,
        bias=False,
    )
# Residual block: two conv+norm+activation stages plus a shortcut.
class ResidualBlock(nn.Module):
    def __init__(self, input_dim, output_dim, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # Main path: two 3x3 conv -> norm -> activation stages.
        self.convs = nn.Sequential(
            conv3x3(input_dim, output_dim),
            norm_layer(output_dim),
            act_layer(),
            conv3x3(output_dim, output_dim),
            norm_layer(output_dim),
            act_layer(),
        )
        # Shortcut: identity when channels match, 1x1 projection otherwise.
        if input_dim != output_dim:
            self.shortcut = nn.Conv2d(input_dim, output_dim, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        # Sum of the main path and the (possibly projected) input.
        shortcut_out = self.shortcut(x)
        return self.convs(x) + shortcut_out
# Encoder stage: residual block, then spatial downsampling by `ratio`.
class Down(nn.Module):
    def __init__(self, input_dim, output_dim, ratio=2, layer='conv'):
        super().__init__()
        self.res_block = ResidualBlock(input_dim, output_dim)
        if layer == 'conv':
            # Learned downsampling: strided conv (kernel ratio+1, padding 1).
            self.down = nn.Sequential(
                nn.Conv2d(
                    output_dim,
                    output_dim,
                    kernel_size=ratio + 1,
                    stride=ratio,
                    padding=1,
                ),
                nn.ReLU(),
            )
        else:
            # Parameter-free downsampling via average pooling.
            self.down = nn.AvgPool2d(kernel_size=ratio, stride=ratio)

    def forward(self, x):
        return self.down(self.res_block(x))
# Decoder stage: upsample, halve channels, concat the skip, residual block.
class Up(nn.Module):
    def __init__(self, input_dim, output_dim, ratio=2):
        super().__init__()
        # Nearest-neighbor upsampling by `ratio`.
        self.up = nn.Upsample(scale_factor=ratio, mode='nearest')
        # 1x1 conv halves the channel count so that after concatenating the
        # skip connection the total is back to `input_dim`.
        self.reduce = nn.Conv2d(input_dim, input_dim // 2, kernel_size=1)
        self.res_block = ResidualBlock(input_dim, output_dim)

    def forward(self, x, r):
        """`r` is the skip-connection tensor from the matching encoder stage."""
        upsampled = self.reduce(self.up(x))
        merged = torch.cat((upsampled, r), dim=1)
        return self.res_block(merged)
# U-Net with configurable depth and channel widths.
class Unet(nn.Module):
    """U-Net built from residual encoder/decoder stages.

    Depth and widths are controlled by ``num_features``: the sequence may be
    any length >= 2, the first entry is arbitrary, and every later entry must
    be exactly twice the previous one — the decoder's channel-halving +
    skip-concatenation arithmetic relies on that doubling.

    Args:
        input_dim: number of input image channels.
        output_dim: number of output channels; if 1, a sigmoid is applied.
        num_features: per-level channel widths (default (64, 128, 256, 512, 1024)).

    Raises:
        ValueError: if ``num_features`` is too short or violates the doubling rule.
    """

    # NOTE: tuple default instead of a list — avoids the shared mutable
    # default-argument pitfall while staying backward compatible.
    def __init__(self, input_dim, output_dim, num_features=(64, 128, 256, 512, 1024)):
        super().__init__()
        num_features = list(num_features)
        if len(num_features) < 2:
            raise ValueError("num_features needs at least two entries")
        # Validate the doubling invariant up front instead of failing with an
        # opaque shape error inside torch.cat during forward().
        for prev, cur in zip(num_features, num_features[1:]):
            if cur != 2 * prev:
                raise ValueError(
                    f"each num_features entry must double the previous one, got {prev} -> {cur}"
                )
        # Stem: lift the input to the first feature width.
        self.conv_1 = nn.Sequential(
            conv3x3(input_dim, num_features[0]),
            nn.BatchNorm2d(num_features[0]),
            nn.ReLU(),
        )
        # len(num_features) - 1 encoder stages.
        self.encoder_x = nn.ModuleList([
            Down(in_dim, out_dim)
            for in_dim, out_dim in zip(num_features, num_features[1:])
        ])
        # len(num_features) - 1 decoder stages, mirroring the encoder.
        self.decoder_x = nn.ModuleList([
            Up(in_dim, out_dim)
            for in_dim, out_dim in zip(reversed(num_features), reversed(num_features[:-1]))
        ])
        self.conv_out = conv3x3(num_features[0], output_dim)

    def forward(self, x):
        x = self.conv_1(x)
        skips = []  # encoder-stage inputs, reused as skip connections
        for down in self.encoder_x:
            skips.append(x)
            x = down(x)
        # Decoder consumes the skips in reverse (deepest first).
        for up, skip in zip(self.decoder_x, reversed(skips)):
            x = up(x, skip)
        x = self.conv_out(x)
        if x.shape[1] == 1:
            # Single-channel output: squash to [0, 1] (binary-mask convention).
            x = torch.sigmoid(x)
        return x
# Smoke test — guarded so importing this module does not build and run the
# model as a side effect.
if __name__ == "__main__":
    model = Unet(3, 1)
    x = torch.randn(4, 3, 64, 64)
    with torch.no_grad():  # shape check only; no gradients needed
        y = model(x)
    print(y.shape)