#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
class Convolution_block(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU, the basic layer of this network.

    Args:
        in_feature: Number of input channels.
        out_feature: Number of output channels.
        kernel: Convolution kernel size.
        strike: Convolution stride (the misspelled name is kept so existing
            keyword callers keep working).
        padding: Zero padding added to each spatial side. Defaults to 0.
    """

    def __init__(self, in_feature, out_feature, kernel, strike, padding=0):
        super().__init__()
        layers = [
            nn.Conv2d(in_feature, out_feature,
                      kernel_size=kernel, stride=strike, padding=padding),
            nn.BatchNorm2d(out_feature),
            # PyTorch default negative slope (0.01).
            # NOTE(review): the original Darknet uses 0.1 -- confirm intent.
            nn.LeakyReLU(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
class Downsample(nn.Module):
    """Halve the spatial resolution with a strided 3x3 convolution.

    Naming caveat: ``in_feature`` is the number of *output* channels. The
    layer expects ``in_feature // 2`` input channels, so it doubles the
    channel count while halving height and width.

    Args:
        in_feature: Output channel count (the input has ``in_feature // 2``).
    """

    def __init__(self, in_feature):
        super().__init__()
        # kernel 3, stride 2, padding 1: out = (size + 2*1 - 3) // 2 + 1,
        # which is exactly size / 2 for even sizes (e.g. 416 -> 208).
        self.downsample = nn.Conv2d(
            in_channels=in_feature // 2,
            out_channels=in_feature,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        return self.downsample(x)
class Upsample(nn.Module):
    """Double height and width using nearest-neighbour interpolation.

    The module has no parameters; each input pixel is repeated into a 2x2
    patch.
    """

    def forward(self, x):
        interpolate = nn.functional.interpolate
        return interpolate(x, scale_factor=2, mode="nearest")
class Residual_block(nn.Module):
    """Bottleneck residual unit: three conv blocks plus an identity shortcut.

    The body squeezes ``in_feature`` channels to half with a 1x1 conv,
    applies a 3x3 conv (padding 1) at that width, and expands back with a
    1x1 conv. The result has the input's shape and is added to the input.

    Args:
        in_feature: Channel count of both input and output.
    """

    def __init__(self, in_feature):
        super().__init__()
        hidden = in_feature // 2
        self.conv = nn.Sequential(
            # 1x1, stride 1, no padding: spatial size unchanged
            Convolution_block(in_feature, hidden, 1, 1),
            # 3x3, stride 1, padding 1: spatial size unchanged
            Convolution_block(hidden, hidden, 3, 1, 1),
            # 1x1 expansion back to the input width so the sum is valid
            Convolution_block(hidden, in_feature, 1, 1),
        )

    def forward(self, x):
        shortcut = x
        return self.conv(x) + shortcut
class Conv_set(nn.Module):
    """Five-layer convolutional set placed in front of each detection head.

    Alternates 1x1 and 3x3 conv blocks:
    - The first 1x1 halves the channel count to reduce the parameter count.
    - The middle 3x3 / 1x1 / 3x3 layers fuse spatial (3x3) and channel (1x1)
      information at that reduced width.
    - The last 1x1 restores ``in_featurn`` channels.

    Spatial size is preserved throughout.

    Args:
        in_featurn: Channel count of both input and output.
    """

    def __init__(self, in_featurn):
        super().__init__()
        half = in_featurn // 2
        # (in_channels, out_channels, kernel, stride, padding);
        # only the 3x3 layers need padding 1 to keep the spatial size.
        specs = [
            (in_featurn, half, 1, 1, 0),
            (half, half, 3, 1, 1),
            (half, half, 1, 1, 0),
            (half, half, 3, 1, 1),
            (half, in_featurn, 1, 1, 0),
        ]
        self.conv1 = nn.Sequential(*(Convolution_block(*spec) for spec in specs))

    def forward(self, x):
        return self.conv1(x)
class Darknet53(nn.Module):
    """Darknet-53-style backbone with a three-scale YOLOv3-style head.

    ``forward`` returns three raw prediction maps, from highest to lowest
    spatial resolution. For a ``(N, 3, H, W)`` input with ``H`` and ``W``
    divisible by 32 (e.g. 416), the shapes are:
    - ``(N, out_channels, H/8, W/8)``
    - ``(N, out_channels, H/16, W/16)``
    - ``(N, out_channels, H/32, W/32)``

    Other sizes can make the upsampled maps disagree with the skip
    connections, in which case ``torch.cat`` fails.

    Args:
        out_channels: Channels of each prediction map. Defaults to 45, the
            value previously hard-coded (presumably 3 anchors * (5 + 10
            classes) -- confirm against the loss/decoder).
    """

    def __init__(self, out_channels=45):
        super().__init__()
        # ---- backbone (module order unchanged so state_dict / parameter
        # order and initialisation RNG consumption match the original) ----
        self.conv1 = Convolution_block(3, 32, 3, 1, 1)
        self.down1 = Downsample(64)      # 32 -> 64 channels, 1/2 resolution
        self.resnet1 = Residual_block(64)
        self.down2 = Downsample(128)     # 1/4 resolution
        self.resnet2 = nn.Sequential(*[Residual_block(128) for _ in range(2)])
        self.down3 = Downsample(256)     # 1/8 resolution (52x52 for 416 input)
        self.resnet52 = nn.Sequential(*[Residual_block(256) for _ in range(8)])
        self.down4 = Downsample(512)     # 1/16 resolution (26x26)
        # NOTE(review): Darknet-53 in the YOLOv3 paper has 8 residual blocks
        # in this stage (1-2-8-8-4); 9 are built here. Kept unchanged so
        # existing checkpoints still load -- confirm whether intentional.
        self.resnet26 = nn.Sequential(*[Residual_block(512) for _ in range(9)])
        self.down5 = Downsample(1024)    # 1/32 resolution (13x13)
        self.resnet13 = nn.Sequential(*[Residual_block(1024) for _ in range(4)])

        # ---- head, 1/32 scale ----
        self.conv_set1 = Conv_set(1024)
        self.conv13 = Convolution_block(1024, 512, 3, 1, 1)
        self.conv_out13 = nn.Conv2d(512, out_channels, 1)
        # ---- head, 1/16 scale (512 upsampled + 512 skip = 1024) ----
        self.conv2 = Convolution_block(1024, 512, 1, 1)
        self.conv_set2 = Conv_set(1024)
        self.conv26 = Convolution_block(1024, 256, 3, 1, 1)
        self.conv_out26 = nn.Conv2d(256, out_channels, 1)
        # ---- head, 1/8 scale (256 upsampled + 256 skip = 512) ----
        self.conv3 = Convolution_block(1024, 256, 1, 1)
        self.conv_set3 = Conv_set(512)
        self.conv52 = Convolution_block(512, 128, 3, 1, 1)
        self.conv_out52 = nn.Conv2d(128, out_channels, 1)

        # Parameter-free; created once instead of on every forward call.
        self.upsample = Upsample()

    def forward(self, x):
        # Backbone; keep the 1/8 and 1/16 maps for the skip connections.
        x = self.conv1(x)
        x = self.resnet1(self.down1(x))
        x = self.resnet2(self.down2(x))
        resnet52 = self.resnet52(self.down3(x))
        resnet26 = self.resnet26(self.down4(resnet52))
        resnet13 = self.resnet13(self.down5(resnet26))

        # 1/32 scale prediction.
        conv_set1 = self.conv_set1(resnet13)
        conv_out13 = self.conv_out13(self.conv13(conv_set1))

        # 1/16 scale: upsample and fuse with the stage-4 features
        # along the channel dimension.
        up1 = self.upsample(self.conv2(conv_set1))
        conv_set2 = self.conv_set2(torch.cat((up1, resnet26), 1))
        conv_out26 = self.conv_out26(self.conv26(conv_set2))

        # 1/8 scale: upsample and fuse with the stage-3 features.
        up2 = self.upsample(self.conv3(conv_set2))
        conv_set3 = self.conv_set3(torch.cat((up2, resnet52), 1))
        conv_out52 = self.conv_out52(self.conv52(conv_set3))

        return conv_out52, conv_out26, conv_out13
def get_parameter_number(net):
    """Count the scalar parameters of a module.

    Args:
        net: Any ``nn.Module``.

    Returns:
        dict: ``{'Total': n_all, 'Trainable': n_requires_grad}``, where each
        value is the sum of ``numel()`` over the matching parameters.
    """
    total = trainable = 0
    for param in net.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return {'Total': total, 'Trainable': trainable}
if __name__ == '__main__':
    import torch

    # Smoke test: push one random 416x416 RGB image through the network,
    # then report the three output shapes and the parameter count.
    dummy = torch.randn(1, 3, 416, 416)
    model = Darknet53()
    out52, out26, out13 = model(dummy)
    print(out52.shape, out26.shape, out13.shape)
    print(get_parameter_number(model))
# YOLOv3-Darknet53
# (Blog page footer: "Latest recommended article published 2022-05-14 21:47:41".)