U-Net Model Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F


class Double_conv2d_bn_encoder(nn.Module):
    """Encoder block: two 3x3 same-padding convolutions (Conv -> BN -> ReLU),
    followed by 2x2 max pooling. Returns both the pooled output and the
    pre-pooling feature map, which serves as the skip connection.
    Note: in_channels2 must equal out_channels, since the second convolution
    consumes the output of the first.
    """
    def __init__(self, in_channels1, in_channels2, out_channels):
        super(Double_conv2d_bn_encoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels1, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels2, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Pooling has no parameters, but defining it here keeps forward()
        # free of module construction.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, input):
        output = F.relu(self.bn1(self.conv1(input)))
        output = F.relu(self.bn2(self.conv2(output)))
        skip = output  # saved as the skip connection for the decoder
        output = self.pool(output)
        return output, skip
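
As a quick sanity check, the encoder block can be exercised on a dummy tensor. This is an illustrative sketch: the 224x224 size, batch size, and variable names are assumptions, not part of the original post.

enc = Double_conv2d_bn_encoder(3, 64, 64)
x = torch.randn(1, 3, 224, 224)   # dummy RGB batch (illustrative size)
pooled, skip = enc(x)
print(pooled.shape)  # torch.Size([1, 64, 112, 112]) -- halved by max pooling
print(skip.shape)    # torch.Size([1, 64, 224, 224]) -- kept for the decoder
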
class Double_conv2d_bn_decoder(nn.Module):
    """Decoder block: concatenate the encoder skip connection with the
    upsampled input along the channel axis, then apply two 3x3 valid
    (padding=0) convolutions (Conv -> BN -> ReLU). Each valid convolution
    shrinks the spatial size by 2 pixels.
    """
    def __init__(self, in_channels, out_channels):
        super(Double_conv2d_bn_decoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, cat_encoder_layer, input):
        # Both inputs must share the same spatial size for the concatenation.
        cater = torch.cat([cat_encoder_layer, input], dim=1)
        # print(cater.shape)
        output = F.relu(self.bn1(self.conv1(cater)))
        output = F.relu(self.bn2(self.conv2(output)))
        # print("awaiting bilinear interpolation", output.shape)
        return output
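
Analogously, a shape check for the decoder block. The sizes below are illustrative; they correspond to the layer6 stage of the full model when fed 224x224 inputs.

dec = Double_conv2d_bn_decoder(1024, 512)
skip = torch.randn(1, 512, 28, 28)   # encoder skip connection (illustrative)
up = torch.randn(1, 512, 28, 28)     # upsampled bottleneck output (illustrative)
out = dec(skip, up)
print(out.shape)  # torch.Size([1, 512, 24, 24]) -- two valid convs: -4 pixels
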
class Model_unet(nn.Module):
    def __init__(self, in_channels):
        super(Model_unet, self).__init__()
        # Encoder
        self.layer1 = Double_conv2d_bn_encoder(in_channels, 64, 64)
        self.layer2 = Double_conv2d_bn_encoder(64, 128, 128)
        self.layer3 = Double_conv2d_bn_encoder(128, 256, 256)
        self.layer4 = Double_conv2d_bn_encoder(256, 512, 512)
        # Bottleneck: two valid 3x3 convolutions
        self.layer5_1 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=0)
        self.layer5_1_1 = nn.BatchNorm2d(1024)
        self.layer5_2 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0)
        self.layer5_2_1 = nn.BatchNorm2d(1024)
        # Upsampling layers. The unusual kernel/stride/padding combinations are
        # chosen so that each upsampled map exactly matches the spatial size of
        # its skip connection (the arithmetic works out for 224x224 inputs).
        # They must live in __init__, not in forward(): a ConvTranspose2d built
        # inside forward() gets fresh random weights on every call and is never
        # registered with the optimizer, so it cannot be trained.
        self.layer5_3 = nn.ConvTranspose2d(1024, out_channels=512,
                                           kernel_size=1,
                                           stride=3, padding=0, bias=True)
        self.up7 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=4, padding=19, bias=True)
        self.up8 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=4, padding=47, bias=True)
        self.up9 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=4, padding=103, bias=True)
        # Decoder
        self.layer6 = Double_conv2d_bn_decoder(1024, 512)
        self.layer7 = Double_conv2d_bn_decoder(512, 256)
        self.layer8 = Double_conv2d_bn_decoder(256, 128)
        self.layer9_conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=0)
        self.layer9_conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=0)
        self.layer9_bn1 = nn.BatchNorm2d(64)
        self.layer9_bn2 = nn.BatchNorm2d(64)
        # 1x1 output convolution; padding=2 pads the 220x220 map back to 224x224.
        self.layer10 = nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=2)

    def forward(self, input):
        # Encoder path: each block returns (pooled output, skip connection).
        output, save_layer1 = self.layer1(input)
        output, save_layer2 = self.layer2(output)
        output, save_layer3 = self.layer3(output)
        output, save_layer4 = self.layer4(output)
        # Bottleneck
        output = F.relu(self.layer5_1_1(self.layer5_1(output)))
        output = F.relu(self.layer5_2_1(self.layer5_2(output)))
        # Decoder path: upsample, then fuse with the matching skip connection.
        output = self.layer5_3(output)
        output = self.layer6(save_layer4, output)
        output = self.up7(output)
        output = self.layer7(save_layer3, output)
        output = self.up8(output)
        output = self.layer8(save_layer2, output)
        output = self.up9(output)
        output = torch.cat([save_layer1, output], dim=1)
        # Note: the final double convolution uses SELU rather than ReLU.
        output = F.selu(self.layer9_bn1(self.layer9_conv1(output)))
        output = F.selu(self.layer9_bn2(self.layer9_conv2(output)))
        output = self.layer10(output)
        return output
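
Finally, a minimal smoke test of the full model. The 224x224 input size is not stated in the original post, but it is the size the upsampling stride/padding values resolve to; the batch size and tensor names are illustrative.

model = Model_unet(in_channels=3)
x = torch.randn(1, 3, 224, 224)   # only 224x224 matches the upsampling arithmetic
y = model(x)
print(y.shape)  # torch.Size([1, 3, 224, 224]) -- same spatial size as the input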
