Yolov4网络代码Pytorch

按照上图的架构
按照上图的架构，自己重新写了一份 YOLOv4 代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary


class Mish(nn.Module):
    """Mish activation: ``mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x))``.

    Stateless and element-wise; the output has the same shape as the input.
    """

    def forward(self, x):
        gate = torch.tanh(F.softplus(x))
        return x * gate


class CMB(nn.Module):
    """Conv2d -> BatchNorm2d -> Mish block.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size, stride, padding: forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.mish = Mish()
        # NOTE(review): the conv keeps its bias even though BatchNorm follows it;
        # left as-is so parameter names and counts stay unchanged.
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        norm = nn.BatchNorm2d(out_channels)
        self.CB = nn.Sequential(conv, norm)

    def forward(self, x):
        return self.mish(self.CB(x))


class ResidualLayer(nn.Module):
    """Bottleneck residual unit: ``x + CMB_3x3(CMB_1x1(x))``.

    The 1x1 CMB squeezes ``in_channels`` to ``in_channels // 2``. The 3x3 CMB
    (stride 1, padding 1) restores ``in_channels``, so the output shape
    matches the input shape and the skip addition is well-defined.
    """

    def __init__(self, in_channels):
        super().__init__()
        hidden = in_channels // 2
        squeeze = CMB(in_channels, hidden, 1, 1, 0)
        expand = CMB(hidden, in_channels, 3, 1, 1)
        self.Resnet = nn.Sequential(squeeze, expand)

    def forward(self, x):
        shortcut = x
        return shortcut + self.Resnet(x)


class ResidualLayer_1(nn.Module):
    """Residual unit without a bottleneck: ``x + CMB_3x3(CMB_1x1(x))``.

    Unlike ``ResidualLayer``, both CMBs keep the full ``in_channels`` width.
    The 3x3 CMB uses stride 1 and padding 1, so the output shape matches the
    input shape.
    """

    def __init__(self, in_channels):
        super().__init__()
        pointwise = CMB(in_channels, in_channels, 1, 1, 0)
        spatial = CMB(in_channels, in_channels, 3, 1, 1)
        self.Resnet = nn.Sequential(pointwise, spatial)

    def forward(self, x):
        shortcut = x
        return shortcut + self.Resnet(x)


class CSPnet(nn.Module):
    """First CSP stage: stride-2 downsample followed by a one-residual CSP block.

    Args:
        in_channels: channels of the incoming feature map.
        out_channel: channels produced by the stage; also the width of both
            CSP branches.

    The input is first downsampled by a 3x3 stride-2 CMB. That result then
    feeds two branches:

    * ``Seq``: 1x1 CMB -> ResidualLayer -> 1x1 CMB.
    * ``CMB2``: a single 1x1 CMB.

    The two branches are concatenated along dim 1 (``2 * out_channel``
    channels) and fused back to ``out_channel`` by a final 1x1 CMB.
    """

    def __init__(self, in_channels, out_channel):
        super().__init__()
        self.CMB1 = CMB(in_channels, out_channel, 3, 2, 1)
        # Size the residual unit from the branch width so the stage works for any
        # out_channel; a fixed 64 only matches out_channel == 64.
        self.Seq = nn.Sequential(CMB(out_channel, out_channel, 1, 1, 0),
                                 ResidualLayer(out_channel),
                                 CMB(out_channel, out_channel, 1, 1, 0))
        self.CMB2 = CMB(out_channel, out_channel, 1, 1, 0)
        self.CMB = CMB(out_channel * 2, out_channel, 1, 1, 0)

    def forward(self, x):
        CMB1_out = self.CMB1(x)
        Seq_out = self.Seq(CMB1_out)
        CMB2_out = self.CMB2(CMB1_out)
        # Concatenate the residual branch and the shortcut branch channel-wise.
        Seq_cat_CMB2 = torch.cat([Seq_out, CMB2_out], dim=1)
        CSP_out = self.CMB(Seq_cat_CMB2)
        return CSP_out


class CSPnet_2(nn.Module):
    """Fixed-width CSP stage mapping 64 -> 128 channels with a stride-2 downsample.

    A 3x3 stride-2 CMB lifts the input to 128 channels. That result feeds two
    branches:

    * ``Seq``: 1x1 CMB (128 -> 64) -> two residual units -> 1x1 CMB.
    * ``CMB2``: a single 1x1 CMB (128 -> 64).

    The two 64-channel outputs are concatenated and fused by a 1x1 CMB
    (128 -> 128).
    """

    def __init__(self):
        super().__init__()
        self.CMB1 = CMB(64, 128, 3, 2, 1)
        # NOTE(review): the first residual keeps 64 hidden channels
        # (ResidualLayer_1) while the second bottlenecks to 32 (ResidualLayer);
        # confirm this mix is intended.
        self.Seq = nn.Sequential(
            CMB(128, 64, 1, 1, 0),
            ResidualLayer_1(64),
            ResidualLayer(64),
            CMB(64, 64, 1, 1, 0),
        )
        self.CMB2 = CMB(128, 64, 1, 1, 0)
        self.CMB = CMB(128, 128, 1, 1, 0)

    def forward(self, x):
        down = self.CMB1(x)
        merged = torch.cat([self.Seq(down), self.CMB2(down)], dim=1)
        return self.CMB(merged)


class CSPnet_8(nn.Module):
    """CSP stage with a stride-2 downsample and eight residual units.

    Args:
        inchannel: input channels; also the width of each CSP branch.
        outchannel: output channels. ``Mainnet`` builds this stage as
            (128, 256) and as (256, 512).

    The two ``inchannel``-wide branches are concatenated and fed to a 1x1 CMB
    that expects ``outchannel`` inputs. The stage therefore only works when
    ``outchannel == 2 * inchannel``.
    """

    def __init__(self, inchannel, outchannel):
        super().__init__()
        self.CMB1 = CMB(inchannel, outchannel, 3, 2, 1)
        self.Seq = nn.Sequential(
            CMB(outchannel, inchannel, 1, 1, 0),
            *[ResidualLayer(inchannel) for _ in range(8)],
            CMB(inchannel, inchannel, 1, 1, 0),
        )
        self.CMB2 = CMB(outchannel, inchannel, 1, 1, 0)
        self.CMB = CMB(outchannel, outchannel, 1, 1, 0)

    def forward(self, x):
        downsampled = self.CMB1(x)
        residual_branch = self.Seq(downsampled)
        shortcut_branch = self.CMB2(downsampled)
        return self.CMB(torch.cat((residual_branch, shortcut_branch), 1))


class CSPnet_4(nn.Module):
    """CSP stage with a stride-2 downsample and four residual units.

    Args:
        inchannel: input channels; also the width of each CSP branch.
        outchannel: output channels. ``Mainnet`` builds this stage as
            (512, 1024).

    The two ``inchannel``-wide branches are concatenated and fed to a 1x1 CMB
    that expects ``outchannel`` inputs. The stage therefore only works when
    ``outchannel == 2 * inchannel``.
    """

    def __init__(self, inchannel, outchannel):
        super().__init__()
        self.CMB1 = CMB(inchannel, outchannel, 3, 2, 1)
        self.Seq = nn.Sequential(
            CMB(outchannel, inchannel, 1, 1, 0),
            *[ResidualLayer(inchannel) for _ in range(4)],
            CMB(inchannel, inchannel, 1, 1, 0),
        )
        self.CMB2 = CMB(outchannel, inchannel, 1, 1, 0)
        self.CMB = CMB(outchannel, outchannel, 1, 1, 0)

    def forward(self, x):
        downsampled = self.CMB1(x)
        residual_branch = self.Seq(downsampled)
        shortcut_branch = self.CMB2(downsampled)
        return self.CMB(torch.cat((residual_branch, shortcut_branch), 1))


class CBL(nn.Module):
    """Stack of Conv2d -> BatchNorm2d -> LeakyReLU(0.1) units.

    The ``CBL`` argument selects the layout. Every conv has stride 1 and
    padding ``k // 2`` (k = kernel size), so H and W are preserved.

    * ``'once'``:   in -1x1-> out -3x3-> in -1x1-> out
    * ``'second'``: in -1x1-> in//4 -3x3-> in//2 -1x1-> in//4
    * ``'three'``:  in -1x1-> in//2
    * ``'four'``:   in -1x1-> in//2 -3x3-> in -1x1-> in//2 -3x3-> in -1x1-> in//2
    * ``'five'``:   in -1x1-> in*2

    Here ``in`` is ``inchannel`` and ``out`` is ``outchannel``. Only ``'once'``
    uses ``outchannel``; the other layouts derive every width from
    ``inchannel``.

    Raises:
        ValueError: if ``CBL`` names no known layout, so a typo fails at
            construction rather than later inside ``forward``.
    """

    def __init__(self, inchannel, outchannel, CBL):
        super().__init__()
        half, quarter, double = inchannel // 2, inchannel // 4, inchannel * 2
        # Each layout is a list of (in_channels, out_channels, kernel_size) convs.
        layouts = {
            'once': [(inchannel, outchannel, 1), (outchannel, inchannel, 3),
                     (inchannel, outchannel, 1)],
            'second': [(inchannel, quarter, 1), (quarter, half, 3),
                       (half, quarter, 1)],
            'three': [(inchannel, half, 1)],
            'four': [(inchannel, half, 1), (half, inchannel, 3),
                     (inchannel, half, 1), (half, inchannel, 3),
                     (inchannel, half, 1)],
            'five': [(inchannel, double, 1)],
        }
        if CBL not in layouts:
            raise ValueError(
                f"unknown CBL layout {CBL!r}; expected one of {sorted(layouts)}")
        # Flatten into a single Sequential so submodule indices (and hence
        # state_dict keys) stay conv, bn, act, conv, bn, act, ...
        layers = []
        for cin, cout, k in layouts[CBL]:
            layers += [nn.Conv2d(cin, cout, k, 1, k // 2),
                       nn.BatchNorm2d(cout),
                       nn.LeakyReLU(0.1)]
        self.CBL = nn.Sequential(*layers)

    def forward(self, x):
        return self.CBL(x)


class SPP(nn.Module):
    """Spatial pyramid pooling with stride-1 max-pools of size 5, 9 and 13.

    Each pool uses padding ``k // 2``, so the spatial size is preserved. The
    three pooled maps and the untouched input are concatenated along dim 1, in
    the order (pool5, pool9, pool13, input). For an (N, C, H, W) input this
    gives ``4 * C`` output channels.
    """

    def __init__(self):
        super().__init__()
        self.maxpool1, self.maxpool2, self.maxpool3 = (
            nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) for k in (5, 9, 13)
        )

    def forward(self, x):
        pooled = [pool(x) for pool in (self.maxpool1, self.maxpool2, self.maxpool3)]
        return torch.cat(pooled + [x], dim=1)


class UpsampleLayer(nn.Module):
    """Parameter-free 2x nearest-neighbour upsampling via ``F.interpolate``."""

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return upsampled


class Con(nn.Module):
    """Plain 3x3 convolution with bias (stride 1, padding 1); keeps H and W.

    No normalisation or activation is applied.
    """

    def __init__(self, inchannel, outchannel):
        super().__init__()
        self.Con2d = nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        return self.Con2d(x)

class downsample(nn.Module):
    """Plain 3x3 stride-2 convolution with bias (padding 1).

    Halves H and W (rounding up). No normalisation or activation is applied.
    """

    def __init__(self, inchannel, outchannel):
        super().__init__()
        self.Con2d = nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.Con2d(x)


class Mainnet(nn.Module):
    """YOLOv4-style detector: CSP backbone, SPP, and a top-down + bottom-up neck.

    Takes a 3-channel image batch. ``forward`` returns ``(out_1, out_2, out_3)``,
    each with 255 channels, at 1/8, 1/16 and 1/32 of the input resolution:
    76x76, 38x38 and 19x19 for a 608x608 input.

    Assumes H and W are divisible by 32 so the upsampled and lateral feature
    maps line up for concatenation — TODO confirm for other sizes. The 255
    channels are presumably 3 anchors x (80 classes + 5) — confirm.
    """

    def __init__(self):
        super().__init__()
        # Backbone.
        self.CMB = CMB(3, 32, 3, 1, 1)
        self.CSP = CSPnet(32, 64)
        self.CSP2 = CSPnet_2()
        self.CSP8_1 = CSPnet_8(128, 256)
        self.CSP8_2 = CSPnet_8(256, 512)
        self.CSP4 = CSPnet_4(512, 1024)
        # Deepest level: CBL -> SPP -> CBL.
        self.CBL_1 = CBL(1024, 512, 'once')
        self.Spp = SPP()
        self.CBL_2 = CBL(2048, 512, 'second')
        # Top-down path.
        self.CBL_3 = CBL(512, 256, 'three')
        self.up_1 = UpsampleLayer()
        self.CBL_4 = CBL(512, 256, 'four')
        self.CBL_3_1 = CBL(256, 128, 'three')
        self.CBL_4_1 = CBL(256, 128, 'four')
        # Head 1 and bottom-up step to 1/16.
        self.CBL_5 = CBL(128, 256, 'five')
        self.down = downsample(128, 256)
        self.con = Con(256, 255)
        # Head 2 and bottom-up step to 1/32.
        self.CBL_4_2 = CBL(512, 256, 'four')
        self.CBL_5_1 = CBL(256, 512, 'five')
        self.down1 = downsample(256, 512)
        self.con_1 = Con(512, 255)
        # Head 3.
        self.CBL_4_3 = CBL(1024, 512, 'four')
        self.CBL_5_2 = CBL(512, 1024, 'five')
        self.down2 = downsample(512, 1024)  # NOTE(review): not used by forward.
        self.con_2 = Con(1024, 255)
        # Lateral 1x1 projections of the backbone skip features. They need their
        # own weights: reusing CBL_3 / CBL_3_1 would tie each lateral projection
        # to the projection applied before upsampling. They are registered last
        # so the initialisation of every earlier module is unchanged.
        self.CBL_3_lateral = CBL(512, 256, 'three')
        self.CBL_3_1_lateral = CBL(256, 128, 'three')

    def forward(self, x):
        # Backbone: every CSP stage downsamples by 2.
        CMB_out = self.CMB(x)
        CSP1_out = self.CSP(CMB_out)
        CSP2_out = self.CSP2(CSP1_out)
        CSP8_out_1 = self.CSP8_1(CSP2_out)    # 256 ch, 1/8
        CSP8_out_2 = self.CSP8_2(CSP8_out_1)  # 512 ch, 1/16
        CSP4_out = self.CSP4(CSP8_out_2)      # 1024 ch, 1/32

        # Deepest level: CBL -> SPP -> CBL.
        CBL1_out = self.CBL_1(CSP4_out)  # 512 ch
        Spp_out = self.Spp(CBL1_out)     # 2048 ch
        CBL2_out = self.CBL_2(Spp_out)   # 512 ch, 1/32

        # Top-down: 1/32 -> 1/16.
        up1_out = self.up_1(self.CBL_3(CBL2_out))     # 256 ch
        lateral_1 = self.CBL_3_lateral(CSP8_out_2)    # 256 ch
        up1_cat = torch.cat([up1_out, lateral_1], dim=1)
        CBL_4 = self.CBL_4(up1_cat)                   # 512 -> 256 ch

        # Top-down: 1/16 -> 1/8.
        up2_out = self.up_1(self.CBL_3_1(CBL_4))      # 128 ch
        lateral_2 = self.CBL_3_1_lateral(CSP8_out_1)  # 128 ch
        up2_cat = torch.cat([up2_out, lateral_2], dim=1)
        CBL_4_1_out = self.CBL_4_1(up2_cat)           # 256 -> 128 ch

        # Head 1 at 1/8.
        out_1 = self.con(self.CBL_5(CBL_4_1_out))

        # Bottom-up: 1/8 -> 1/16, then head 2.
        down_out = self.down(CBL_4_1_out)             # 256 ch
        cat_3 = torch.cat([down_out, CBL_4], dim=1)   # 512 ch
        CBL_4_2_out = self.CBL_4_2(cat_3)             # 256 ch
        out_2 = self.con_1(self.CBL_5_1(CBL_4_2_out))

        # Bottom-up: 1/16 -> 1/32, then head 3.
        down1_out = self.down1(CBL_4_2_out)             # 512 ch
        cat_4 = torch.cat([down1_out, CBL2_out], dim=1)  # 1024 ch
        CBL_4_3_out = self.CBL_4_3(cat_4)               # 512 ch
        out_3 = self.con_2(self.CBL_5_2(CBL_4_3_out))

        return out_1, out_2, out_3


if __name__ == '__main__':
    # Smoke test: push one random 608x608 RGB image through the network and
    # print the shapes of the three detection-head outputs.
    a = torch.rand(1, 3, 608, 608)
    net = Mainnet()
    with torch.no_grad():
        out1, out2, out3 = net(a)
    print(out1.shape, out2.shape, out3.shape)
    # torchsummary.summary prints its own table and returns None, so its result
    # is not printed. Its default device="cuda" would feed GPU tensors to this
    # CPU-resident model on a CUDA machine, so request the CPU explicitly.
    summary(net, (3, 608, 608), device="cpu")

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值