【语义分割-1】使用 PyTorch 搭建 U-Net

使用 PyTorch 搭建 U-Net

import torch
import torch.nn as nn
from torch.nn import functional as F

# Decoder upsampling: bilinear interpolation with align_corners=True (the mode nn.UpsamplingBilinear2d uses).

def conv2_relu(in_chan, out_chan, pd):
    """Return two stacked (3x3 Conv2d -> in-place ReLU) stages as an nn.Sequential.

    Args:
        in_chan: channels entering the first convolution.
        out_chan: channels produced by both convolutions.
        pd: zero padding applied by each 3x3 convolution (pd=1 keeps H and W unchanged).

    Returns:
        nn.Sequential laid out as [Conv2d, ReLU, Conv2d, ReLU].
    """
    conv_a = nn.Conv2d(in_channels=in_chan, out_channels=out_chan, kernel_size=3, padding=pd)
    conv_b = nn.Conv2d(in_channels=out_chan, out_channels=out_chan, kernel_size=3, padding=pd)
    return nn.Sequential(conv_a, nn.ReLU(inplace=True), conv_b, nn.ReLU(inplace=True))


def conv3_relu(in_chan, out_chan, pd):
    """Return three stacked (3x3 Conv2d -> in-place ReLU) stages as an nn.Sequential.

    Args:
        in_chan: channels entering the first convolution.
        out_chan: channels produced by every convolution.
        pd: zero padding applied by each 3x3 convolution (pd=1 keeps H and W unchanged).

    Returns:
        nn.Sequential laid out as [Conv2d, ReLU, Conv2d, ReLU, Conv2d, ReLU].
    """
    stages = []
    channels_in = in_chan
    for _ in range(3):
        stages += [
            nn.Conv2d(in_channels=channels_in, out_channels=out_chan, kernel_size=3, padding=pd),
            nn.ReLU(inplace=True),
        ]
        # Only the first convolution changes the channel count.
        channels_in = out_chan
    return nn.Sequential(*stages)


class U_Net(nn.Module):
    """U-Net style encoder-decoder producing per-pixel class scores.

    Encoder: five convolution stages (conv2_relu / conv3_relu, all with padding=1)
    separated by 2x2 max pooling. Decoder: four stages that each bilinearly upsample
    the previous map to the size of the matching encoder feature, concatenate that
    feature along the channel axis, and apply conv2_relu. A final 1x1 convolution
    maps the 64 channels to ``num_cls`` channels.

    Args:
        num_cls: number of output channels (classes).
        in_channels: channels of the input tensor. Defaults to 3, the value the
            original implementation hard-coded.
    """

    def __init__(self, num_cls, in_channels=3):
        super(U_Net, self).__init__()

        self.num_cls = num_cls

        # Encoder convolutions.
        self.conv_1 = conv2_relu(in_chan=in_channels, out_chan=64, pd=1)
        self.conv_2 = conv2_relu(in_chan=64, out_chan=128, pd=1)
        self.conv_3 = conv3_relu(in_chan=128, out_chan=256, pd=1)
        self.conv_4 = conv3_relu(in_chan=256, out_chan=512, pd=1)
        self.conv_5 = conv3_relu(in_chan=512, out_chan=512, pd=1)

        # Decoder convolutions; in_chan = skip-feature channels + upsampled channels.
        self.conv_6 = conv2_relu(in_chan=1024, out_chan=512, pd=1)  # 512 (f4) + 512
        self.conv_7 = conv2_relu(in_chan=768, out_chan=256, pd=1)   # 256 (f3) + 512
        self.conv_8 = conv2_relu(in_chan=384, out_chan=128, pd=1)   # 128 (f2) + 256
        self.conv_9 = conv2_relu(in_chan=192, out_chan=64, pd=1)    # 64 (f1) + 128
        self.conv_final = nn.Conv2d(in_channels=64, out_channels=self.num_cls, kernel_size=1)

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Parameter-free; kept so the module's attributes/repr stay unchanged for
        # existing callers. forward() resizes with _upsample_like instead, which uses
        # the same bilinear/align_corners=True mode but targets the skip feature's size.
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)

    @staticmethod
    def _upsample_like(x, ref):
        """Bilinearly resize ``x`` to ``ref``'s spatial size (align_corners=True).

        When ref is exactly twice x's size, this matches UpsamplingBilinear2d(scale_factor=2).
        Otherwise, e.g. when MaxPool2d floored an odd size, the output still matches
        ref so the following torch.cat succeeds.
        """
        return F.interpolate(x, size=ref.shape[-2:], mode="bilinear", align_corners=True)

    def forward(self, x):
        """Run the encoder-decoder on ``x``.

        Args:
            x: tensor of shape (N, in_channels, H, W). H and W must be at least 16 so
                the fourth pooling still yields a non-empty map. They need not be
                divisible by 16.

        Returns:
            Tensor of shape (N, num_cls, H, W).
        """
        # Encoder (downsampling). "//" marks the floor applied by MaxPool2d(2, 2).
        f1 = self.conv_1(x)   # (N, 64, H, W)
        x = self.maxpool(f1)  # (N, 64, H//2, W//2)

        f2 = self.conv_2(x)   # (N, 128, H//2, W//2)
        x = self.maxpool(f2)  # (N, 128, H//4, W//4)

        f3 = self.conv_3(x)   # (N, 256, H//4, W//4)
        x = self.maxpool(f3)  # (N, 256, H//8, W//8)

        f4 = self.conv_4(x)   # (N, 512, H//8, W//8)
        x = self.maxpool(f4)  # (N, 512, H//16, W//16)

        f5 = self.conv_5(x)   # (N, 512, H//16, W//16)

        # Decoder (upsampling). Each step resizes to the skip feature's exact size
        # rather than a fixed x2, so the concatenation shapes always agree.
        p5_up = self._upsample_like(f5, f4)  # (N, 512, H//8, W//8)

        p4 = torch.cat([f4, p5_up], dim=1)   # (N, 1024, H//8, W//8); 512+512
        p4 = self.conv_6(p4)                 # (N, 512, H//8, W//8)
        p4_up = self._upsample_like(p4, f3)  # (N, 512, H//4, W//4)

        p3 = torch.cat([f3, p4_up], dim=1)   # (N, 768, H//4, W//4); 256+512
        p3 = self.conv_7(p3)                 # (N, 256, H//4, W//4)
        p3_up = self._upsample_like(p3, f2)  # (N, 256, H//2, W//2)

        p2 = torch.cat([f2, p3_up], dim=1)   # (N, 384, H//2, W//2); 128+256
        p2 = self.conv_8(p2)                 # (N, 128, H//2, W//2)
        p2_up = self._upsample_like(p2, f1)  # (N, 128, H, W)

        p1 = torch.cat([f1, p2_up], dim=1)   # (N, 192, H, W); 64+128
        p1 = self.conv_9(p1)                 # (N, 64, H, W)
        p1 = self.conv_final(p1)             # (N, num_cls, H, W)

        return p1


# Shape smoke test: build a 9-class model and run one random 416x416 RGB image through it.
model = U_Net(num_cls=9)
# U_Net has no dropout/batch-norm layers, so eval() changes nothing numerically;
# it only makes the inference-only intent explicit.
model.eval()
data = torch.rand([1, 3, 416, 416]).float()
# This forward pass is never back-propagated. no_grad() skips recording the autograd
# graph, so the full-resolution intermediate activations are not kept alive for it.
with torch.no_grad():
    out = model(data)
print('out.shape:', out.shape)

输出结果:

out.shape: torch.Size([1, 9, 416, 416])

参考链接:
憨批的语义分割3——unet模型详解以及训练自己的unet模型(划分斑马线)

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值