浅浅记录一下自己修改过的网络。
一.common.py 最后添加下列:
###############----- ShuffleNetV2 backbone -----#######
def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
    """Interleave the channels of a 4-D tensor across ``groups`` groups.

    The channel axis is viewed as ``(groups, channels_per_group)``, those two
    axes are swapped, and the result is flattened back to a single channel
    axis. The output has the same shape as the input.

    Args:
        x: Tensor whose ``size()`` unpacks into (batch, channels, height, width).
        groups: Number of channel groups; ``channels`` should be divisible by it,
            otherwise the first ``view`` fails.

    Returns:
        A tensor of the same shape as ``x`` with its channels reordered.
    """
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # (B, C, H, W) -> (B, groups, C // groups, H, W)
    x = x.view(batchsize, groups, channels_per_group, height, width)
    # Swap the two channel axes; contiguous() is needed before the next view().
    x = torch.transpose(x, 1, 2).contiguous()
    # Flatten back to (B, C, H, W).
    return x.view(batchsize, -1, height, width)


class conv_bn_relu_maxpool(nn.Module):
    """Stem block: 3x3 stride-2 conv + BatchNorm + LeakyReLU, then 3x3 stride-2 max-pool.

    Both stages use stride 2 with padding 1, so height and width are reduced
    by a factor of about 4 overall.
    """

    def __init__(self, c1, c2):
        """Build the stem.

        Args:
            c1: Number of input channels.
            c2: Number of output channels.
        """
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c2),
            nn.LeakyReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

    def forward(self, x):
        """Apply the conv stage followed by the max-pool."""
        return self.maxpool(self.conv(x))


class ShuffleNetV2_InvertedResidual(nn.Module):
    """ShuffleNetV2 unit with two branches followed by a channel shuffle.

    * ``stride == 1``: the input is split in half along the channel axis; the
      first half is passed through unchanged, the second half goes through
      ``branch2``; the two halves are concatenated.
    * ``stride > 1``: the full input goes through both ``branch1`` and
      ``branch2`` (each producing ``oup // 2`` channels) and the results are
      concatenated.

    In both cases the concatenated tensor is shuffled with ``channel_shuffle(out, 2)``.
    """

    def __init__(self, inp: int, oup: int, stride: int) -> None:
        """Build the unit.

        Args:
            inp: Number of input channels.
            oup: Number of output channels; each branch produces ``oup // 2``.
            stride: Stride of the depthwise 3x3 convolutions, in ``1..3``.

        Raises:
            ValueError: If ``stride`` is outside ``1..3``, or if ``stride == 1``
                and ``inp != 2 * (oup // 2)`` (the pass-through half must match
                the branch width).
        """
        super().__init__()

        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.stride == 1 and inp != 2 * branch_features:
            raise ValueError(
                f'stride=1 requires inp == 2 * (oup // 2), got inp={inp}, oup={oup}'
            )

        if self.stride > 1:
            # Downsampling shortcut: depthwise 3x3 (strided) -> pointwise 1x1.
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            # Unused when stride == 1 (forward passes the first half through as-is).
            self.branch1 = nn.Sequential()

        # Main branch: pointwise 1x1 -> depthwise 3x3 (strided) -> pointwise 1x1.
        branch2_in = inp if self.stride > 1 else branch_features
        self.branch2 = nn.Sequential(
            nn.Conv2d(branch2_in, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3,
                                stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            # NOTE(review): in the collapsed original it is ambiguous whether this
            # pointwise conv was commented out; it is kept active here, matching
            # the standard ShuffleNetV2 unit — confirm against the author's file.
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(
        i: int,
        o: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False
    ) -> nn.Conv2d:
        """Return a depthwise ``nn.Conv2d`` (``groups`` equals the input channel count ``i``)."""
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the unit on ``x`` and return the channel-shuffled result."""
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        return channel_shuffle(out, 2)
####################-------END ShuffleNetV2------##################
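二.找到yolo.py中的def parse_model(d, ch),把ShuffleNetV2_InvertedResidual添加到其中按(输入通道c1, 输出通道c2)解析参数的模块列表里(即与Conv、C3等并列的 `if m in [...]` 分支),这样yaml里的[128, 2]等参数才会被展开为(inp, oup, stride)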
三.重新创建一个.yaml文件,下面是我根据自己的网络创建的
# YOLOv5 model definition using ShuffleNetV2_InvertedResidual units as the backbone.
nc: 1  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
# custom backbone
# NOTE(review): the stock YOLOv5 Focus layer downsamples by 2 (not 4); if that is
# the Focus used here, every P-level label below is one level too high — confirm.
backbone:
  # [from, number, module, args]
  [[-1, 1, Focus, [64, 3]],  # 0-P2/4
   [-1, 1, ShuffleNetV2_InvertedResidual, [128, 2]],  # 1-P3/8
   [-1, 3, ShuffleNetV2_InvertedResidual, [128, 1]],  # 2
   [-1, 1, ShuffleNetV2_InvertedResidual, [256, 2]],  # 3-P4/16
   [-1, 7, ShuffleNetV2_InvertedResidual, [256, 1]],  # 4
   [-1, 1, ShuffleNetV2_InvertedResidual, [512, 2]],  # 5-P5/32
   [-1, 3, ShuffleNetV2_InvertedResidual, [512, 1]],  # 6
  ]

# YOLOv5 head
head:
  [[-1, 1, Conv, [512, 1, 1]],  # 7
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],  # 8
   [[-1, 4], 1, Concat, [1]],  # 9  cat backbone P4
   [-1, 1, C3, [512, False]],  # 10

   [-1, 1, Conv, [256, 1, 1]],  # 11
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],  # 12
   [[-1, 2], 1, Concat, [1]],  # 13  cat backbone P3
   [-1, 1, C3, [256, False]],  # 14 (P3/8-small)

   [-1, 1, Conv, [128, 3, 2]],  # 15
   [[-1, 11], 1, Concat, [1]],  # 16  cat head P4
   [-1, 1, C3, [256, False]],  # 17 (P4/16-medium)

   [-1, 1, Conv, [256, 3, 2]],  # 18
   [[-1, 7], 1, Concat, [1]],  # 19  cat head P5
   [-1, 1, C3, [512, False]],  # 20 (P5/32-large)

   [[14, 17, 20], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
四.修改train.py的参数:主要就是修改cfg里对应的yaml路径
最后运行train.py 就可以啦!
注意!!!:替换成轻量化的主干网络虽然能让网络运行得更快,但精度大概率会下降,所以只适合替换着玩玩、做做实验。