class Double_conv2d_bn_encoder(nn.Module):
    """U-Net encoder stage: two (3x3 conv -> BatchNorm -> ReLU) layers, then 2x2 max-pool.

    Both convolutions use padding=1, so they preserve H and W; only the
    final max-pool (kernel 2, stride 2) halves the spatial size.

    Args:
        in_channels1: number of channels of the tensor passed to ``forward``.
        in_channels2: input channels of the second convolution. It consumes
            the first convolution's output, so this must equal ``out_channels``.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels1, in_channels2, out_channels):
        super(Double_conv2d_bn_encoder, self).__init__()
        # Kept as single-element nn.Sequential containers so existing
        # state_dict keys (e.g. "d_conv2d_bn_1.0.weight") stay valid.
        self.d_conv2d_bn_1 = nn.Sequential(
            nn.Conv2d(in_channels1, out_channels, kernel_size=3, padding=1),
        )
        self.d_conv2d_bn_2 = nn.Sequential(
            nn.Conv2d(in_channels2, out_channels, kernel_size=3, padding=1),
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, input):
        """Run the stage.

        Returns:
            A tuple ``(pooled, skip)``. ``skip`` is the activation after the
            second conv/BN/ReLU, kept for the decoder's skip connection.
            ``pooled`` is ``skip`` max-pooled with kernel 2 / stride 2.
        """
        output = F.relu(self.bn1(self.d_conv2d_bn_1(input)))
        skip = F.relu(self.bn2(self.d_conv2d_bn_2(output)))
        # Functional pooling: max-pool has no parameters, so there is no need
        # to construct an nn.MaxPool2d module on every forward call.
        pooled = F.max_pool2d(skip, kernel_size=2, stride=2)
        return pooled, skip
class Double_conv2d_bn_decoder(nn.Module):
    """U-Net decoder stage: channel-concat, then two (3x3 conv -> BatchNorm -> ReLU) layers.

    Both convolutions use padding=0 (valid convolution), so each one shrinks
    H and W by 2; the stage output is 4 pixels smaller per spatial dimension
    than its inputs.

    Args:
        in_channels: channels after concatenation, i.e. channels of
            ``cat_encoder_layer`` plus channels of ``input``.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(Double_conv2d_bn_decoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, cat_encoder_layer, input):
        """Concatenate the skip tensor with ``input`` along dim 1 and refine.

        ``torch.cat`` requires both tensors to agree on every dimension except
        dim 1, so the caller must supply matching batch size and H/W.
        """
        cater = torch.cat([cat_encoder_layer, input], dim=1)
        output = F.relu(self.bn1(self.conv1(cater)))
        output = F.relu(self.bn2(self.conv2(output)))
        return output
class Model_unet(nn.Module):
    """U-Net style encoder/decoder network.

    The hard-coded paddings (unpadded bottleneck convs, transposed-conv
    paddings 19/47/103, and padding=2 on the final 1x1 conv) make the skip
    connections line up only for 224x224 inputs. For an (N, C, 224, 224)
    input the spatial sizes are:
        encoder skips: 224, 112, 56, 28 (pooled to 14)
        bottleneck:    14 -> 12 -> 10 -> 28 (layer5_3)
        decoder:       layer6 24 -> up7 56 -> layer7 52 -> up8 112
                       -> layer8 108 -> up9 224 -> layer9 222 -> 220
        layer10:       220 -> 224
    Other sizes make ``torch.cat`` fail on mismatched H/W.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map (default 3, as before).
    """

    def __init__(self, in_channels, out_channels=3):
        super(Model_unet, self).__init__()
        # --- encoder ---
        # layer1/layer2 keep their nn.Sequential wrappers so existing
        # state_dict keys ("layer1.0.*", "layer2.0.*") remain loadable.
        self.layer1 = nn.Sequential(
            Double_conv2d_bn_encoder(in_channels, 64, 64),
        )
        self.layer2 = nn.Sequential(
            Double_conv2d_bn_encoder(64, 128, 128)
        )
        self.layer3 = Double_conv2d_bn_encoder(128, 256, 256)
        self.layer4 = Double_conv2d_bn_encoder(256, 512, 512)
        # --- bottleneck (unpadded convs shrink H/W by 2 each) ---
        self.layer5_1 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=0)
        self.layer5_1_1 = nn.BatchNorm2d(1024)
        self.layer5_2 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0)
        self.layer5_2_1 = nn.BatchNorm2d(1024)
        # NOTE(review): stride (3) > kernel_size (1), so most output positions
        # receive only the bias; consider a kernel >= stride or interpolation.
        self.layer5_3 = nn.ConvTranspose2d(1024, out_channels=512,
                                           kernel_size=1,
                                           stride=3, padding=0, bias=True)
        # --- decoder ---
        self.layer6 = Double_conv2d_bn_decoder(1024, 512)
        self.layer7 = Double_conv2d_bn_decoder(512, 256)
        self.layer8 = Double_conv2d_bn_decoder(256, 128)
        self.layer9_conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=0)
        self.layer9_conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=0)
        self.layer9_bn1 = nn.BatchNorm2d(64)
        self.layer9_bn2 = nn.BatchNorm2d(64)
        # NOTE(review): padding=2 on a 1x1 conv restores 220 -> 224, but the
        # 2-pixel border of the output is bias-only.
        self.layer10 = nn.Conv2d(64, out_channels, kernel_size=1, stride=1, padding=2)
        # Up-sampling layers between decoder stages. These were previously
        # constructed inside forward(), which gave them fresh random weights
        # on every call and left them untrained, unsaved and not moved by
        # .to(device). They are registered last so the order and keys of all
        # pre-existing parameters are unchanged.
        # NOTE(review): stride (4) > kernel_size (2) leaves output positions
        # with no input contribution (bias only).
        self.up7 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=4, padding=19, bias=True)
        self.up8 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=4, padding=47, bias=True)
        self.up9 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=4, padding=103, bias=True)

    def forward(self, input):
        """Map an (N, in_channels, 224, 224) batch to (N, out_channels, 224, 224)."""
        # Encoder: each stage returns (pooled, pre-pool skip activation).
        output, save_layer1 = self.layer1(input)
        output, save_layer2 = self.layer2(output)
        output, save_layer3 = self.layer3(output)
        output, save_layer4 = self.layer4(output)
        # Bottleneck.
        output = F.relu(self.layer5_1_1(self.layer5_1(output)))
        output = F.relu(self.layer5_2_1(self.layer5_2(output)))
        output = self.layer5_3(output)
        # Decoder: concat with the matching skip, refine, up-sample.
        output = self.layer6(save_layer4, output)
        output = self.layer7(save_layer3, self.up7(output))
        output = self.layer8(save_layer2, self.up8(output))
        output = torch.cat([save_layer1, self.up9(output)], dim=1)
        # The last stage uses SELU rather than ReLU, as in the original design.
        output = F.selu(self.layer9_bn1(self.layer9_conv1(output)))
        output = F.selu(self.layer9_bn2(self.layer9_conv2(output)))
        return self.layer10(output)
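# Unet模型实现 (UNet model implementation)
# 最新推荐文章于 2024-04-01 17:20:14 发布 (blog page footer: "latest recommended article published 2024-04-01 17:20:14")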