PyTorch Semantic Segmentation (2) -------- Building the Model

The classic model is still UNet. You could also train with an off-the-shelf UNet implementation, but to understand the architecture better, we will build it ourselves: an encoder that downsamples step by step, and a decoder that upsamples and fuses the matching encoder features through skip connections.

unet.py:

import torch
import torch.nn as nn
import torch.nn.functional as F


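# Conv-BN-ReLU block followed by 2x nearest-neighbour upsampling.
# Note: this block is kept for reference but is not used by MainNet below;
# the decoder is built from UpConcat instead.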
class Up(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Up, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.block(x)
        out = F.interpolate(x, scale_factor=2)
        return out


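# Encoder block: a single 3x3 Conv-BN-ReLU; with the default stride=2 it
# halves the spatial resolution while changing the channel count.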
class Down(nn.Module):
    def __init__(self, in_channel, out_channel, stride=2):
        super(Down, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, stride, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        return self.block(x)


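# Decoder block: upsample the deeper feature map 2x, concatenate it with the
# skip feature from the encoder along the channel dimension, then refine the
# fused map with two 3x3 convolutions (ReLU6 activations).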
class UpConcat(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(UpConcat, self).__init__()
        self.up = nn.Upsample(scale_factor=2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel+out_channel, out_channel, kernel_size=3, padding=1),
            nn.ReLU6(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            nn.ReLU6(inplace=True),
        )

    def forward(self, in_map1, in_map2):
        in_map2 = self.up(in_map2)
        out = torch.cat([in_map1, in_map2], dim=1)
        return self.conv2(out)


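# UNet-style network: a 5-stage encoder (64 -> 1024 channels, 512 -> 32 px)
# and a 4-stage decoder that restores the input resolution via skip connections.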
class MainNet(nn.Module):
    def __init__(self, num_classes):
        super(MainNet, self).__init__()
        self.down1 = Down(3, 64, stride=1)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 512)
        self.down5 = Down(512, 1024)

        # self.conv = nn.Conv2d(1024, 512, 3, 1, 1)

        self.up5concat = UpConcat(1024, 512)
        self.up4concat = UpConcat(512, 256)
        self.up3concat = UpConcat(256, 128)
        self.up2concat = UpConcat(128, 64)

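        # 1x1 conv maps 64 channels to num_classes; Sigmoid gives per-channel
        # probabilities, which suits binary / multi-label masks. For mutually
        # exclusive classes, raw logits + CrossEntropyLoss is the usual choice.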
        self.head = nn.Sequential(
            nn.Conv2d(64, num_classes, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        feat1 = self.down1(x)       # (3, 512, 512)   -> (64, 512, 512)
        feat2 = self.down2(feat1)   # (64, 512, 512)  -> (128, 256, 256)
        feat3 = self.down3(feat2)   # (128, 256, 256) -> (256, 128, 128)
        feat4 = self.down4(feat3)   # (256, 128, 128) -> (512, 64, 64)
        feat5 = self.down5(feat4)   # (512, 64, 64)   -> (1024, 32, 32)
        # print("feat5:", feat5.shape)   # debug shape check
        # feat5 = self.conv(feat5)

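        # Decoder: each UpConcat upsamples the deeper map and fuses it with
        # the matching encoder feature (skip connection).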
        feat4_up = self.up5concat(feat4, feat5)      # -> (512, 64, 64)
        feat3_up = self.up4concat(feat3, feat4_up)   # -> (256, 128, 128)
        feat2_up = self.up3concat(feat2, feat3_up)   # -> (128, 256, 256)
        feat1_up = self.up2concat(feat1, feat2_up)   # -> (64, 512, 512)
        return self.head(feat1_up)


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tensor = torch.zeros((1, 3, 512, 512)).to(device)
    model = MainNet(num_classes=3).to(device)

    # print(model)

    out = model(tensor)
    print(out.shape)    # expect torch.Size([1, 3, 512, 512])

    from torchsummary import summary
    summary(model, (3, 512, 512), device=device.type)   # pass device so this also runs on CPU
    # Optional: measure FLOPs and parameter count
    # from torchstat import stat
    # stat(model, (3, 512, 512))

    # from thop import profile
    # flops, params = profile(model, inputs=(tensor,))
    # print("FLOPs =", flops / 1e9, "G")
    # print("params =", params / 1e6, "M")
    #
    # Measured for this model: FLOPs = 63.406604288 G, params = 14.127683 M

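To see how this model would plug into a training loop, here is a minimal sketch of one optimization step. This is an illustrative assumption, not part of the original post: nn.BCELoss is chosen to match the Sigmoid head (binary / multi-label masks), and the batch is random placeholder data rather than a real dataset.

import torch
import torch.nn as nn

# from unet import MainNet   # assuming the file above is saved as unet.py

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MainNet(num_classes=3).to(device)
criterion = nn.BCELoss()    # pairs with the Sigmoid head
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Placeholder batch: a random image and a random 3-channel binary mask
images = torch.randn(2, 3, 512, 512, device=device)
masks = torch.randint(0, 2, (2, 3, 512, 512), device=device).float()

model.train()
optimizer.zero_grad()
preds = model(images)            # (2, 3, 512, 512), values in (0, 1)
loss = criterion(preds, masks)
loss.backward()
optimizer.step()
print("loss:", loss.item())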