按照上图的架构，自己重新写了一份 YOLOv4 代码：
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
# Mish activation: mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x))
class Mish(nn.Module):
    """Element-wise Mish activation; the module has no learnable parameters."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # softplus(x) = ln(1 + e^x); its tanh acts as a smooth gate on x.
        gate = F.softplus(x).tanh()
        return x * gate
# CMB = Conv2d -> BatchNorm2d -> Mish
class CMB(nn.Module):
    """Convolution, then batch normalisation, then a Mish activation.

    The five arguments are forwarded positionally to ``nn.Conv2d``
    (in_channels, out_channels, kernel_size, stride, padding); the
    BatchNorm2d normalises ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.mish = Mish()
        # NOTE(review): Conv2d keeps its default bias=True even though BN
        # follows it, which makes the bias redundant. It is left in place so
        # that existing state_dicts still load.
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        norm = nn.BatchNorm2d(out_channels)
        self.CB = nn.Sequential(conv, norm)

    def forward(self, x):
        return self.mish(self.CB(x))
class ResidualLayer(nn.Module):
    """Residual unit: a 1x1 CMB and a 3x3 CMB whose output is added to the input.

    Args:
        in_channels: channels of the input. The output has the same count,
            since the block is additive.
        hidden_channels: width of the intermediate 1x1 layer. Defaults to
            ``in_channels // 2``, which is the original bottleneck behaviour.
            Passing ``in_channels`` yields the same layout as
            ``ResidualLayer_1``.
    """

    def __init__(self, in_channels, hidden_channels=None):
        super(ResidualLayer, self).__init__()
        if hidden_channels is None:
            hidden_channels = in_channels // 2
        self.Resnet = nn.Sequential(
            CMB(in_channels, hidden_channels, 1, 1, 0),
            # 3x3 with padding 1 keeps H and W, so the skip-add is shape-safe.
            CMB(hidden_channels, in_channels, 3, 1, 1),
        )

    def forward(self, x):
        return x + self.Resnet(x)
class ResidualLayer_1(nn.Module):
    """Residual unit without a bottleneck.

    Both the 1x1 CMB and the 3x3 CMB keep ``in_channels`` channels, and the
    result is added back onto the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        width = in_channels
        self.Resnet = nn.Sequential(
            CMB(width, width, 1, 1, 0),
            CMB(width, width, 3, 1, 1),
        )

    def forward(self, x):
        residual = self.Resnet(x)
        return x + residual
class CSPnet(nn.Module):
    """First CSP stage: stride-2 downsample, then a split into two fused branches.

    Layout:
      * a 3x3, stride-2 CMB maps ``in_channels`` to ``out_channel``, halving
        H and W (rounding up);
      * the residual branch is a 1x1 CMB, one ``ResidualLayer``, then a 1x1 CMB;
      * the shortcut branch is a single 1x1 CMB;
      * both branches are ``out_channel`` wide, and their concatenation
        (``2 * out_channel``) is fused back to ``out_channel`` by a 1x1 CMB.

    Args:
        in_channels: channels of the incoming feature map.
        out_channel: channels of the downsampled map, of each branch, and of
            the output.
    """

    def __init__(self, in_channels, out_channel):
        super().__init__()
        self.CMB1 = CMB(in_channels, out_channel, 3, 2, 1)
        # The residual unit must match the branch width. It was previously
        # hard-coded to ResidualLayer(64), which only worked for out_channel == 64.
        self.Seq = nn.Sequential(CMB(out_channel, out_channel, 1, 1, 0),
                                 ResidualLayer(out_channel),
                                 CMB(out_channel, out_channel, 1, 1, 0))
        self.CMB2 = CMB(out_channel, out_channel, 1, 1, 0)
        self.CMB = CMB(out_channel * 2, out_channel, 1, 1, 0)

    def forward(self, x):
        CMB1_out = self.CMB1(x)
        Seq_out = self.Seq(CMB1_out)
        CMB2_out = self.CMB2(CMB1_out)
        Seq_cat_CMB2 = torch.cat([Seq_out, CMB2_out], dim=1)
        return self.CMB(Seq_cat_CMB2)
class CSPnet_2(nn.Module):
    """Second CSP stage with fixed widths: 64 -> 128 channels, H and W halved.

    A 3x3, stride-2 CMB goes to 128 channels. The residual branch is 1x1 to
    64, ``ResidualLayer_1(64)``, ``ResidualLayer(64)``, then 1x1. The shortcut
    branch is 1x1 to 64. The concatenation (128 channels) is fused by a 1x1 CMB.

    NOTE(review): the two residual units differ. ResidualLayer_1 keeps 64
    hidden channels, while ResidualLayer halves them to 32. Confirm that this
    asymmetry is intended.
    """

    def __init__(self):
        super().__init__()
        self.CMB1 = CMB(64, 128, 3, 2, 1)
        residual_branch = [
            CMB(128, 64, 1, 1, 0),
            ResidualLayer_1(64),
            ResidualLayer(64),
            CMB(64, 64, 1, 1, 0),
        ]
        self.Seq = nn.Sequential(*residual_branch)
        self.CMB2 = CMB(128, 64, 1, 1, 0)
        self.CMB = CMB(128, 128, 1, 1, 0)

    def forward(self, x):
        down = self.CMB1(x)
        merged = torch.cat([self.Seq(down), self.CMB2(down)], dim=1)
        return self.CMB(merged)
class CSPnet_8(nn.Module):
    """CSP stage with eight residual units.

    In Mainnet it is used twice, as CSPnet_8(128, 256) and CSPnet_8(256, 512).

    Layout:
      * a 3x3, stride-2 CMB maps ``inchannel`` to ``outchannel``;
      * the residual branch is 1x1 to ``inchannel``, 8 x ``ResidualLayer``,
        then 1x1;
      * the shortcut branch is 1x1 to ``inchannel``;
      * the concatenation is fused by ``CMB(outchannel, outchannel)``, so
        ``outchannel`` must equal ``2 * inchannel``.

    NOTE(review): ResidualLayer halves the hidden width. The reference darknet
    yolov4.cfg keeps it equal to the branch width in these stages; confirm
    this against the intended architecture.
    """

    def __init__(self, inchannel, outchannel):
        super().__init__()
        self.CMB1 = CMB(inchannel, outchannel, 3, 2, 1)
        self.Seq = nn.Sequential(
            CMB(outchannel, inchannel, 1, 1, 0),
            *[ResidualLayer(inchannel) for _ in range(8)],
            CMB(inchannel, inchannel, 1, 1, 0),
        )
        self.CMB2 = CMB(outchannel, inchannel, 1, 1, 0)
        self.CMB = CMB(outchannel, outchannel, 1, 1, 0)

    def forward(self, x):
        stem = self.CMB1(x)
        residual_path = self.Seq(stem)
        shortcut_path = self.CMB2(stem)
        return self.CMB(torch.cat([residual_path, shortcut_path], dim=1))
class CSPnet_4(nn.Module):
    """CSP stage with four residual units; Mainnet uses it as CSPnet_4(512, 1024).

    Layout:
      * a 3x3, stride-2 CMB maps ``inchannel`` to ``outchannel``;
      * the residual branch is 1x1 to ``inchannel``, 4 x ``ResidualLayer``,
        then 1x1;
      * the shortcut branch is 1x1 to ``inchannel``;
      * the concatenation is fused by ``CMB(outchannel, outchannel)``, so
        ``outchannel`` must equal ``2 * inchannel``.
    """

    def __init__(self, inchannel, outchannel):
        super().__init__()
        self.CMB1 = CMB(inchannel, outchannel, 3, 2, 1)
        self.Seq = nn.Sequential(
            CMB(outchannel, inchannel, 1, 1, 0),
            *[ResidualLayer(inchannel) for _ in range(4)],
            CMB(inchannel, inchannel, 1, 1, 0),
        )
        self.CMB2 = CMB(outchannel, inchannel, 1, 1, 0)
        self.CMB = CMB(outchannel, outchannel, 1, 1, 0)

    def forward(self, x):
        stem = self.CMB1(x)
        residual_path = self.Seq(stem)
        shortcut_path = self.CMB2(stem)
        return self.CMB(torch.cat([residual_path, shortcut_path], dim=1))
class CBL(nn.Module):
    """A stack of Conv2d -> BatchNorm2d -> LeakyReLU(0.1) units used in the neck.

    Every conv has stride 1. A 1x1 conv uses padding 0 and a 3x3 conv uses
    padding 1, so H and W are preserved.

    Args:
        inchannel: number of input channels.
        outchannel: number of output channels. Only the ``'once'`` layout
            reads it; the other layouts derive every width from ``inchannel``.
        CBL: layout selector. The supported values are:
            * ``'once'``: in->out (1x1), out->in (3x3), in->out (1x1).
              Emits ``outchannel`` channels.
            * ``'second'``: in->in/4 (1x1), in/4->in/2 (3x3), in/2->in/4 (1x1).
              Emits ``inchannel // 4`` channels.
            * ``'three'``: in->in/2 (1x1). Emits ``inchannel // 2`` channels.
            * ``'four'``: five convs alternating 1x1 and 3x3 between in and
              in/2. Emits ``inchannel // 2`` channels.
            * ``'five'``: in->2*in (1x1). Emits ``inchannel * 2`` channels.

    Raises:
        ValueError: if ``CBL`` is not one of the layouts above. Previously an
            unknown value left ``self.CBL`` unset, and the error only surfaced
            later as an AttributeError in ``forward``.
    """

    def __init__(self, inchannel, outchannel, CBL):
        super().__init__()
        half, quarter, double = inchannel // 2, inchannel // 4, inchannel * 2
        # (conv_in, conv_out, kernel) for each Conv-BN-LeakyReLU unit, in order.
        layouts = {
            'once': [(inchannel, outchannel, 1), (outchannel, inchannel, 3),
                     (inchannel, outchannel, 1)],
            'second': [(inchannel, quarter, 1), (quarter, half, 3),
                       (half, quarter, 1)],
            'three': [(inchannel, half, 1)],
            'four': [(inchannel, half, 1), (half, inchannel, 3),
                     (inchannel, half, 1), (half, inchannel, 3),
                     (inchannel, half, 1)],
            'five': [(inchannel, double, 1)],
        }
        if CBL not in layouts:
            raise ValueError(
                f"unknown CBL layout {CBL!r}; expected one of {sorted(layouts)}")
        layers = []
        for conv_in, conv_out, kernel in layouts[CBL]:
            # padding = kernel // 2 gives 0 for 1x1 and 1 for 3x3, matching
            # the original hand-written layers. The Sequential indices and the
            # module construction order are also unchanged.
            layers += [nn.Conv2d(conv_in, conv_out, kernel, 1, kernel // 2),
                       nn.BatchNorm2d(conv_out),
                       nn.LeakyReLU(0.1)]
        self.CBL = nn.Sequential(*layers)

    def forward(self, x):
        return self.CBL(x)
class SPP(nn.Module):
    """Spatial pyramid pooling using stride-1 max pools of size 5, 9 and 13.

    Each pool pads by ``k // 2``, so H and W are unchanged. The output
    concatenates [pool5(x), pool9(x), pool13(x), x] along dim 1, giving four
    times the input channel count.
    """

    def __init__(self):
        super().__init__()
        self.maxpool1 = self._same_size_pool(5)
        self.maxpool2 = self._same_size_pool(9)
        self.maxpool3 = self._same_size_pool(13)

    @staticmethod
    def _same_size_pool(k):
        # An odd kernel with stride 1 and padding k // 2 keeps the spatial size.
        return nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        branches = (self.maxpool1, self.maxpool2, self.maxpool3)
        return torch.cat([pool(x) for pool in branches] + [x], dim=1)
class UpsampleLayer(torch.nn.Module):
    """Nearest-neighbour 2x spatial upsampling with no learnable parameters."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Each input pixel is replicated into a 2x2 patch.
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return upsampled
class Con(nn.Module):
    """A plain 3x3 convolution (stride 1, padding 1) that keeps H and W.

    It has no BatchNorm or activation. In this file it produces each of
    Mainnet's three 255-channel outputs.
    """

    def __init__(self, inchannel, outchannel):
        super().__init__()
        self.Con2d = nn.Conv2d(in_channels=inchannel, out_channels=outchannel,
                               kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        return self.Con2d(x)
class downsample(nn.Module):
    """A plain 3x3 convolution with stride 2 and padding 1.

    It halves H and W, rounding up, and has no BatchNorm or activation.
    Mainnet uses it on the bottom-up path.
    """

    def __init__(self, inchannel, outchannel):
        super().__init__()
        self.Con2d = nn.Conv2d(in_channels=inchannel, out_channels=outchannel,
                               kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.Con2d(x)
class Mainnet(nn.Module):
    """YOLOv4-style detector: a CSP backbone, an SPP + top-down/bottom-up neck, and three heads.

    ``forward(x)`` takes a batch of 3-channel images and returns
    ``(out_1, out_2, out_3)``. Each output has 255 channels, at spatial
    strides 8, 16 and 32 respectively (76x76, 38x38 and 19x19 for a 608x608
    input). H and W should be multiples of 32; otherwise the upsampled and
    lateral maps can differ in size and ``torch.cat`` fails.
    """

    def __init__(self):
        super().__init__()
        # Backbone
        self.CMB = CMB(3, 32, 3, 1, 1)
        self.CSP = CSPnet(32, 64)
        self.CSP2 = CSPnet_2()
        self.CSP8_1 = CSPnet_8(128, 256)
        self.CSP8_2 = CSPnet_8(256, 512)
        self.CSP4 = CSPnet_4(512, 1024)
        # SPP block on the deepest (stride-32) map
        self.CBL_1 = CBL(1024, 512, 'once')
        self.Spp = SPP()
        self.CBL_2 = CBL(2048, 512, 'second')
        # Top-down path
        self.CBL_3 = CBL(512, 256, 'three')
        self.up_1 = UpsampleLayer()
        self.CBL_4 = CBL(512, 256, 'four')
        self.CBL_3_1 = CBL(256, 128, 'three')
        self.CBL_4_1 = CBL(256, 128, 'four')
        # Stride-8 head and the first bottom-up downsample
        self.CBL_5 = CBL(128, 256, 'five')
        self.down = downsample(128, 256)
        self.con = Con(256, 255)
        # Stride-16 head
        self.CBL_4_2 = CBL(512, 256, 'four')
        self.CBL_5_1 = CBL(256, 512, 'five')
        self.down1 = downsample(256, 512)
        self.con_1 = Con(512, 255)
        # Stride-32 head
        self.CBL_4_3 = CBL(1024, 512, 'four')
        self.CBL_5_2 = CBL(512, 1024, 'five')
        # NOTE(review): down2 is never used in forward. It is kept only so
        # that existing state_dicts still load.
        self.down2 = downsample(512, 1024)
        self.con_2 = Con(1024, 255)
        # Lateral 1x1 projections of the backbone maps.
        # Previously forward reused self.CBL_3 / self.CBL_3_1 for these, which
        # tied their weights and mixed their BN running statistics with the
        # top-down path.
        # They are registered last so that every module above keeps its
        # original initialisation order. Checkpoints saved before this change
        # lack these two entries; load them with strict=False.
        self.CBL_3_lateral = CBL(512, 256, 'three')
        self.CBL_3_1_lateral = CBL(256, 128, 'three')

    def forward(self, x):
        # ---- Backbone (the comments give the stride relative to the input) ----
        CMB_out = self.CMB(x)                       # /1,  32 ch
        CSP1_out = self.CSP(CMB_out)                # /2,  64 ch
        CSP2_out = self.CSP2(CSP1_out)              # /4,  128 ch
        CSP8_out_1 = self.CSP8_1(CSP2_out)          # /8,  256 ch
        CSP8_out_2 = self.CSP8_2(CSP8_out_1)        # /16, 512 ch
        CSP4_out = self.CSP4(CSP8_out_2)            # /32, 1024 ch

        # ---- SPP ----
        CBL1_out = self.CBL_1(CSP4_out)             # /32, 512 ch
        Spp_out = self.Spp(CBL1_out)                # /32, 2048 ch
        CBL2_out = self.CBL_2(Spp_out)              # /32, 512 ch

        # ---- Top-down path ----
        up1_out = self.up_1(self.CBL_3(CBL2_out))           # /16, 256 ch
        lateral_1 = self.CBL_3_lateral(CSP8_out_2)          # /16, 256 ch
        CBL_4 = self.CBL_4(torch.cat([up1_out, lateral_1], dim=1))     # /16, 256 ch

        up2_out = self.up_1(self.CBL_3_1(CBL_4))            # /8, 128 ch
        lateral_2 = self.CBL_3_1_lateral(CSP8_out_1)        # /8, 128 ch
        CBL_4_1_out = self.CBL_4_1(torch.cat([up2_out, lateral_2], dim=1))  # /8, 128 ch

        # ---- Bottom-up path and heads ----
        out_1 = self.con(self.CBL_5(CBL_4_1_out))           # /8 head
        down_out = self.down(CBL_4_1_out)                   # /16, 256 ch
        CBL_4_2_out = self.CBL_4_2(torch.cat([down_out, CBL_4], dim=1))      # /16, 256 ch
        out_2 = self.con_1(self.CBL_5_1(CBL_4_2_out))       # /16 head
        down1_out = self.down1(CBL_4_2_out)                 # /32, 512 ch
        CBL_4_3_out = self.CBL_4_3(torch.cat([down1_out, CBL2_out], dim=1))  # /32, 512 ch
        out_3 = self.con_2(self.CBL_5_2(CBL_4_3_out))       # /32 head
        return out_1, out_2, out_3
if __name__ == '__main__':
    # Smoke test: push one random 608x608 RGB image through the network and
    # print the shapes of the three head outputs.
    a = torch.rand(1, 3, 608, 608)
    net = Mainnet()
    with torch.no_grad():  # shape check only, so no autograd graph is needed
        out1, out2, out3 = net(a)
    print(out1.shape, out2.shape, out3.shape)
    # torchsummary defaults to device="cuda" and puts its dummy input on the
    # GPU whenever one is available, while `net` stays on the CPU. That caused
    # a device-mismatch error, so the device is passed explicitly here.
    # summary() prints its table itself and returns None, so it is not
    # wrapped in print().
    summary(net, (3, 608, 608), device='cpu')