如有不足，欢迎指正。
Unet 代码如下：
from torch import nn
import torch
import torch.nn.functional as F
def contracting_block(in_channels, out_channels):
    """Build the two-stage 3x3 convolution block used throughout the encoder.

    Each stage is Conv2d(3x3, padding=1) -> ReLU -> BatchNorm2d, so height and
    width are preserved and only the channel count changes.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.

    Returns:
        nn.Sequential with six children: conv, relu, bn, conv, relu, bn.
    """
    layers = []
    # The first conv maps in -> out, the second keeps out -> out.
    for source_channels in (in_channels, out_channels):
        layers.append(nn.Conv2d(in_channels=source_channels, out_channels=out_channels, kernel_size=3, padding=1))
        layers.append(nn.ReLU())
        layers.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*layers)
class expansive_block(nn.Module):
    """Decoder stage: upsample ``d``, align it to the skip tensor ``e``, fuse with two convs.

    Args:
        in_channels: channels of the decoder input ``d``. The transposed conv
            halves this, and the first fusion conv expects ``in_channels``
            channels after concatenation, so ``e`` must carry
            ``in_channels - in_channels // 2`` channels.
        mid_channels: channels after the first fusion conv.
        out_channels: channels of the block output.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super(expansive_block, self).__init__()
        # 2x learned upsampling: (H - 1) * 2 - 2 + 3 + 1 = 2H; channels halved.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1)
        fuse = [
            nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channels),
            nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        ]
        self.block = nn.Sequential(*fuse)

    def forward(self, e, d):
        """Upsample ``d``, pad it to ``e``'s height/width, concat ``[e, d]`` and fuse."""
        upsampled = self.up(d)
        # Split the size gap as evenly as possible between the two sides so the
        # upsampled map lines up with the skip connection before concatenation.
        gap_h = e.shape[2] - upsampled.shape[2]
        gap_w = e.shape[3] - upsampled.shape[3]
        left = gap_w // 2
        top = gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        return self.block(torch.cat([e, upsampled], dim=1))
def final_block(in_channels, out_channels):
    """Build the output head: 3x3 conv (padding=1) -> ReLU -> BatchNorm2d.

    Args:
        in_channels: channels of the last decoder feature map.
        out_channels: channels of the network output.

    Returns:
        nn.Sequential with children named "0", "1", "2".

    NOTE(review): the head ends in ReLU -> BatchNorm2d, so the network output
    is not a plain linear projection; a bare (often 1x1) conv is the more common
    segmentation head. Changing it would alter the state_dict layout.
    """
    head = nn.Sequential()
    head.add_module("0", nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1))
    head.add_module("1", nn.ReLU())
    head.add_module("2", nn.BatchNorm2d(out_channels))
    return head
class Unet(nn.Module):
    """U-Net with five encoder levels, a 2048-channel bottleneck and five decoder levels.

    Encoder widths are 64-128-256-512-1024. Every encoder stage is followed by
    2x2 max pooling, so the bottleneck runs at 1/32 of the input resolution.
    Each ``expansive_block`` upsamples, pads the result to its skip connection,
    concatenates and fuses.

    Args:
        in_channels: channels of the input image.
        out_channels: channels produced by ``final_block``.
    """

    def __init__(self, in_channels, out_channels):
        super(Unet, self).__init__()
        self.conv_encode1 = contracting_block(in_channels=in_channels, out_channels=64)
        self.conv_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode2 = contracting_block(in_channels=64, out_channels=128)
        self.conv_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode3 = contracting_block(in_channels=128, out_channels=256)
        self.conv_pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode4 = contracting_block(in_channels=256, out_channels=512)
        self.conv_pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode5 = contracting_block(in_channels=512, out_channels=1024)
        self.conv_pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        # padding=1 makes the bottleneck size-preserving like every other 3x3
        # conv in this network. Without it each conv shrank the map by 2 pixels:
        # - the decoder had to zero-pad the upsampled bottleneck back to the
        #   skip size;
        # - any input whose 1/32-scale map is below 5x5 (H or W < 160) crashed.
        # Padding does not change weight shapes, so existing checkpoints still
        # load.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=1024, out_channels=2048, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(2048),
            nn.Conv2d(kernel_size=3, in_channels=2048, out_channels=2048, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(2048),
        )
        # Each decoder halves the channels via its transposed conv, then
        # concatenates the same-width encoder output.
        self.conv_decode5 = expansive_block(2048, 1024, 1024)
        self.conv_decode4 = expansive_block(1024, 512, 512)
        self.conv_decode3 = expansive_block(512, 256, 256)
        self.conv_decode2 = expansive_block(256, 128, 128)
        self.conv_decode1 = expansive_block(128, 64, 64)
        self.final_layer = final_block(64, out_channels)

    def forward(self, x):
        """Run encoder, bottleneck and decoder; return the ``final_block`` output."""
        encode_block1 = self.conv_encode1(x)
        pool1 = self.conv_pool1(encode_block1)
        encode_block2 = self.conv_encode2(pool1)
        pool2 = self.conv_pool2(encode_block2)
        encode_block3 = self.conv_encode3(pool2)
        pool3 = self.conv_pool3(encode_block3)
        encode_block4 = self.conv_encode4(pool3)
        pool4 = self.conv_pool4(encode_block4)
        encode_block5 = self.conv_encode5(pool4)
        pool5 = self.conv_pool5(encode_block5)
        bridge = self.bottleneck(pool5)
        decoder5 = self.conv_decode5(encode_block5, bridge)
        decoder4 = self.conv_decode4(encode_block4, decoder5)
        decoder3 = self.conv_decode3(encode_block3, decoder4)
        decoder2 = self.conv_decode2(encode_block2, decoder3)
        decoder1 = self.conv_decode1(encode_block1, decoder2)
        final_layer = self.final_layer(decoder1)
        return final_layer
Unet2P 代码如下：
from torch import nn
import torch
import torch.nn.functional as F
def contracting_block(in_channels, out_channels):
    """Return two stacked Conv2d(3x3, padding=1) -> ReLU -> BatchNorm2d stages.

    Spatial size is preserved; the first conv maps ``in_channels`` to
    ``out_channels`` and the second keeps ``out_channels``.
    """
    def stage(channels_in):
        # One conv/activation/norm triple ending in ``out_channels`` channels.
        return (
            nn.Conv2d(in_channels=channels_in, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )

    return nn.Sequential(*stage(in_channels), *stage(out_channels))
def up(in_channels):
    """Return a learned 2x upsampler that halves the channel count.

    kernel 3, stride 2, padding 1, output_padding 1 gives an output size of
    (H - 1) * 2 - 2 + 3 + 1 = 2H, i.e. height and width exactly double.
    """
    half = in_channels // 2
    return nn.ConvTranspose2d(in_channels, half, kernel_size=3, stride=2, padding=1, output_padding=1)
def final_block(in_channels, out_channels):
    """Build the output head: 3x3 conv (padding=1) -> ReLU -> BatchNorm2d.

    NOTE(review): ending the head in ReLU -> BatchNorm2d is unusual for a
    segmentation output; the deep-supervision heads in ``Unet2P_`` are plain
    convs, so the two output modes are not structurally consistent.
    """
    layers = [
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
    ]
    return nn.Sequential(*layers)
class Unet2P_(nn.Module):
    """Nested (UNet++-style) U-Net with four pooling levels.

    Node ``x{i}_{j}`` sits at depth ``i`` (spatial size divided by 2**i). For
    j >= 1 it concatenates every earlier node at the same depth with the
    2x-upsampled node ``x{i+1}_{j-1}`` and fuses them with a
    ``contracting_block``.

    Inputs of any size are accepted. MaxPool2d floors odd sizes, so each
    upsampled map is zero-padded to its skip size before concatenation.
    For inputs divisible by 16 that padding is a no-op.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of every output head.
        deep_supervision: when True, ``forward`` returns a list with one output
            per top-row node ``x0_1 .. x0_4``; otherwise a single tensor
            computed from ``x0_4``.
    """

    def __init__(self, in_channels, out_channels, deep_supervision=False):
        super(Unet2P_, self).__init__()
        self.deep_supervision = deep_supervision
        self.pool = nn.MaxPool2d(2, 2)
        # NOTE(review): up0 and up5 are never used in forward(). Their
        # parameters receive no gradient (DistributedDataParallel would need
        # find_unused_parameters=True). They are kept so existing state_dicts
        # still load and parameter initialisation order is unchanged.
        self.up0 = up(64)
        self.up1 = up(128)
        self.up2 = up(256)
        self.up3 = up(512)
        self.up4 = up(1024)
        self.up5 = up(2048)
        # Encoder column (j = 0).
        self.conv1 = contracting_block(in_channels, 64)
        self.conv2 = contracting_block(64, 128)
        self.conv3 = contracting_block(128, 256)
        self.conv4 = contracting_block(256, 512)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
        )
        # Used only when deep_supervision is False.
        self.final_layer = final_block(64, out_channels)
        # Nested nodes: input channels = (j + 1) * width at depth i
        # (j earlier same-depth nodes plus one upsampled deeper node).
        self.con01 = contracting_block(128, 64)
        self.con11 = contracting_block(256, 128)
        self.con21 = contracting_block(512, 256)
        self.con31 = contracting_block(1024, 512)
        self.con02 = contracting_block(192, 64)
        self.con12 = contracting_block(128 * 3, 128)
        self.con22 = contracting_block(256 * 3, 256)
        self.con03 = contracting_block(64 * 4, 64)
        self.con13 = contracting_block(128 * 4, 128)
        self.con04 = contracting_block(64 * 5, 64)
        if self.deep_supervision:
            self.final1 = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)
            self.final2 = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)
            self.final3 = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)
            self.final4 = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)

    @staticmethod
    def _align(src, ref):
        """Zero-pad ``src`` in H/W (split evenly per side) to match ``ref``'s spatial size."""
        diff_h = ref.size(2) - src.size(2)
        diff_w = ref.size(3) - src.size(3)
        if diff_h == 0 and diff_w == 0:
            return src
        return F.pad(src, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])

    def _fuse(self, conv, upsample, skips, deeper):
        """Apply ``conv`` to the channel-concat of ``skips`` and ``upsample(deeper)``."""
        lifted = self._align(upsample(deeper), skips[0])
        return conv(torch.cat([*skips, lifted], 1))

    def forward(self, x):
        """Compute all nested nodes; return the head output(s) described in the class docstring."""
        x0_0 = self.conv1(x)
        x1_0 = self.conv2(self.pool(x0_0))
        x2_0 = self.conv3(self.pool(x1_0))
        x3_0 = self.conv4(self.pool(x2_0))
        x4_0 = self.bottleneck(self.pool(x3_0))

        x0_1 = self._fuse(self.con01, self.up1, [x0_0], x1_0)
        x1_1 = self._fuse(self.con11, self.up2, [x1_0], x2_0)
        x2_1 = self._fuse(self.con21, self.up3, [x2_0], x3_0)
        x3_1 = self._fuse(self.con31, self.up4, [x3_0], x4_0)

        x0_2 = self._fuse(self.con02, self.up1, [x0_0, x0_1], x1_1)
        x1_2 = self._fuse(self.con12, self.up2, [x1_0, x1_1], x2_1)
        x2_2 = self._fuse(self.con22, self.up3, [x2_0, x2_1], x3_1)

        x0_3 = self._fuse(self.con03, self.up1, [x0_0, x0_1, x0_2], x1_2)
        x1_3 = self._fuse(self.con13, self.up2, [x1_0, x1_1, x1_2], x2_2)

        x0_4 = self._fuse(self.con04, self.up1, [x0_0, x0_1, x0_2, x0_3], x1_3)

        if self.deep_supervision:
            return [self.final1(x0_1), self.final2(x0_2), self.final3(x0_3), self.final4(x0_4)]
        return self.final_layer(x0_4)
Unet3P 代码如下：
from torch import nn
import torch
import torch.nn.functional as F
def contracting_block(in_channels, out_channels):
    """Two Conv2d(3x3, padding=1) -> ReLU -> BatchNorm2d stages; H and W are preserved.

    Args:
        in_channels: channels entering the first conv.
        out_channels: channels produced by both convs.
    """
    modules = [
        layer
        for source in (in_channels, out_channels)
        for layer in (
            nn.Conv2d(in_channels=source, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )
    ]
    return nn.Sequential(*modules)
def up(in_channels):
    """Return a transposed conv that doubles H and W and halves the channel count."""
    upsampler = nn.ConvTranspose2d(
        in_channels=in_channels,
        out_channels=in_channels // 2,
        kernel_size=3,
        stride=2,
        padding=1,
        output_padding=1,  # (H - 1) * 2 - 2 + 3 + 1 = 2H
    )
    return upsampler
def up1(in_channels):
    """Return a learned 4x upsampler that divides the channel count by 4.

    kernel_size equals stride, so every output pixel is written by exactly one
    input pixel and the output is exactly 4H x 4W. A 3x3 kernel with stride 4
    never reaches every fourth output row and column, leaving a grid of
    bias-only holes.

    NOTE: the weight shape is (in, in // 4, 4, 4); checkpoints saved with the
    previous 3x3 kernel will not load into this layer.
    """
    scale = 4
    return nn.ConvTranspose2d(in_channels, in_channels // scale, kernel_size=scale, stride=scale)
def up2(in_channels):
    """Return a learned 8x upsampler that divides the channel count by 8.

    kernel_size equals stride, so each input pixel fills its own 8x8 output
    patch and the output is exactly 8H x 8W. A 3x3 kernel with stride 8 only
    reaches 3 of every 8 output rows/columns; the rest are bias-only.

    NOTE: the weight shape is (in, in // 8, 8, 8); checkpoints saved with the
    previous 3x3 kernel will not load into this layer.
    """
    scale = 8
    return nn.ConvTranspose2d(in_channels, in_channels // scale, kernel_size=scale, stride=scale)
def up3(in_channels):
    """Return a learned 16x upsampler that divides the channel count by 16.

    kernel_size equals stride, so each input pixel fills its own 16x16 output
    patch and the output is exactly 16H x 16W. A 3x3 kernel with stride 16 only
    reaches 3 of every 16 output rows/columns; the rest are bias-only.

    NOTE: the weight shape is (in, in // 16, 16, 16); checkpoints saved with the
    previous 3x3 kernel will not load into this layer.
    """
    scale = 16
    return nn.ConvTranspose2d(in_channels, in_channels // scale, kernel_size=scale, stride=scale)
def final_block(in_channels, out_channels):
    """Build the output head: 3x3 conv (padding=1), then ReLU, then BatchNorm2d.

    NOTE(review): a ReLU -> BatchNorm2d tail on the final logits is unusual;
    a bare conv is the common choice. Kept as-is to preserve the state_dict.
    """
    projection = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
    activation = nn.ReLU()
    norm = nn.BatchNorm2d(out_channels)
    return nn.Sequential(projection, activation, norm)
class Unet3P(nn.Module):
    """U-Net with full-scale skip connections (UNet3+-style) and five encoder levels.

    Each decoder stage concatenates five feature maps brought to its own
    resolution and fuses them with a ``contracting_block``:
    - deeper encoder/decoder outputs, upsampled;
    - the same-level encoder output;
    - shallower encoder outputs, max-pooled.

    Input height and width must be multiples of 16, otherwise the
    concatenations in ``forward`` fail on mismatched sizes. For example, a
    512x512 input gives x5 at 32x32.

    Args:
        in_channel: channels of the input image.
        out_channel: channels produced by the final head.
    """

    def __init__(self,in_channel,out_channel):
        super(Unet3P,self).__init__()
        # Encoder.
        self.conv1=contracting_block(in_channel,64)
        self.conv2=contracting_block(64,128)
        self.conv3=contracting_block(128,256)
        self.conv4=contracting_block(256,512)
        self.conv5=contracting_block(512,1024)
        # Max-pool downsamplers by 2, 4, 8, 16, 32.
        # NOTE(review): down3 and down4 are not used in forward().
        self.down=nn.MaxPool2d(2,2)
        self.down1=nn.MaxPool2d(4,4)
        self.down2=nn.MaxPool2d(8,8)
        self.down3=nn.MaxPool2d(16,16)
        self.down4=nn.MaxPool2d(32,32)
        # 2x upsamplers (channels halved).
        # NOTE(review): self.up1 is never used in forward(); it is kept so
        # existing state_dicts still load and initialisation order is unchanged.
        self.up1 = up(64)
        self.up2 = up(128)
        self.up3 = up(256)
        self.up4 = up(512)
        self.up5 = up(1024)
        # Decoder fusion blocks: five concatenated inputs of equal width.
        self.c4=contracting_block(512*5,512)
        self.c3=contracting_block(256*5,256)
        self.c2=contracting_block(128*5,128)
        self.c1=contracting_block(64*5,64)
        # Channel adapters for pooled shallow encoder maps.
        self.con1=contracting_block(64,512)
        self.con2=contracting_block(128,512)
        self.con3=contracting_block(256,512)
        self.con11=contracting_block(64,256)
        # 4x / 8x / 16x upsamplers for deep-to-shallow full-scale skips.
        self.up55=up1(1024)
        self.up44=up1(512)
        self.up33=up1(256)
        self.up555=up2(1024)
        self.up444=up2(512)
        self.up5555=up3(1024)
        self.final=final_block(64,out_channel)

    def forward(self,x):
        """Run the encoder and the four full-scale decoder stages; return the head output."""
        # Encoder: channels set by each block, spatial size halves per level.
        x1 = self.conv1(x)              # (N, 64, H, W)
        x2 = self.conv2(self.down(x1))  # (N, 128, H/2, W/2)
        x3 = self.conv3(self.down(x2))  # (N, 256, H/4, W/4)
        x4 = self.conv4(self.down(x3))  # (N, 512, H/8, W/8)
        x5 = self.conv5(self.down(x4))  # (N, 1024, H/16, W/16)

        # Decoder stage at H/8: 5 x 512 channels.
        d4 = self.c4(torch.cat([
            self.up5(x5),
            x4,
            self.con3(self.down(x3)),
            self.con2(self.down1(x2)),
            self.con1(self.down2(x1)),
        ], dim=1))

        # Decoder stage at H/4: 5 x 256 channels.
        # NOTE(review): the fourth slot is meant to carry x2 max-pooled to this
        # scale. Passing it through the encoder's own self.conv3 just
        # reproduces x3 (same block, same input) and updates conv3's
        # BatchNorm running statistics a second time, so x3 is reused directly.
        # A dedicated 128->256 block was probably intended, but adding one
        # would change the state_dict.
        d3 = self.c3(torch.cat([
            self.up55(x5),
            self.up4(d4),
            x3,
            x3,
            self.con11(self.down1(x1)),
        ], dim=1))

        # Decoder stage at H/2: 5 x 128 channels.
        # NOTE(review): as above, self.conv2(self.down(x1)) would recompute x2
        # exactly, so x2 is reused.
        d2 = self.c2(torch.cat([
            self.up555(x5),
            self.up44(d4),
            self.up3(d3),
            x2,
            x2,
        ], dim=1))

        # Decoder stage at full resolution: 5 x 64 channels.
        d1 = self.c1(torch.cat([
            self.up5555(x5),
            self.up444(d4),
            self.up33(d3),
            self.up2(d2),
            x1,
        ], dim=1))

        return self.final(d1)
if __name__ == '__main__':
    # Smoke test: build Unet3P and print the output shape for one 512x512 RGB image.
    print('*--' * 5)
    rgb = torch.randn(1, 3, 512, 512)
    net = Unet3P(3, 2)
    # Only the output shape is inspected, so skip autograd bookkeeping; at
    # 512x512 the activations kept for backward would otherwise use a lot of memory.
    with torch.no_grad():
        out = net(rgb)
    print(out.shape)