（1）核心代码（新建文件 ultralytics/nn/modules/shufflenetv2.py）
import torch
import torch.nn as nn
class Conv_maxpool(nn.Module):
    """Downsampling stem: 3x3 stride-2 conv + BN + ReLU, then a 3x3 stride-2 max-pool.

    The two stride-2 stages shrink the spatial size by 4x overall
    (for example 3x640x640 -> c2x160x160).

    Args:
        c1: number of input channels.
        c2: number of output channels.
    """

    def __init__(self, c1, c2):  # ch_in, ch_out
        super().__init__()
        stem_layers = [
            nn.Conv2d(c1, c2, 3, 2, 1, bias=False),  # kernel 3, stride 2, padding 1
            nn.BatchNorm2d(c2),
            nn.ReLU(inplace=True),
        ]
        # Attribute names / Sequential indices determine the state_dict keys.
        self.conv = nn.Sequential(*stem_layers)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Run the conv stage and max-pool its result."""
        features = self.conv(x)
        return self.maxpool(features)
class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 unit (the building block torchvision calls ``InvertedResidual``).

    Two paths are concatenated along the channel axis and then channel-shuffled
    with 2 groups:

    * ``stride == 1``: the input is split into two channel halves; the first half
      passes through unchanged and the second half goes through ``branch2``.
      The spatial size is kept.
    * ``stride > 1``: the full input goes through both ``branch1`` and ``branch2``.
      Each applies a strided 3x3 depthwise conv and yields ``oup // 2`` channels,
      so the unit downsamples spatially by ``stride``.

    The output has ``2 * (oup // 2)`` channels (i.e. ``oup`` when ``oup`` is even).

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise convolutions; 1, 2 or 3.

    Raises:
        ValueError: if ``stride`` is not in ``1..3``, or if ``stride == 1`` and
            ``inp != 2 * (oup // 2)`` (the untouched half plus ``branch2`` must add
            back up to the input width).
    """

    def __init__(self, inp, oup, stride):  # ch_in, ch_out, stride
        super().__init__()
        # Explicit checks instead of ``assert`` (stripped under ``python -O``).
        # Same stride range as torchvision's ShuffleNetV2 unit; other values used to
        # build a module whose branches could not be concatenated.
        if not 1 <= stride <= 3:
            raise ValueError(f"ShuffleNetV2: stride must be 1, 2 or 3, got {stride}")
        self.stride = stride
        branch_features = oup // 2
        if stride == 1 and inp != branch_features << 1:
            raise ValueError(f"ShuffleNetV2: stride=1 requires inp == 2 * (oup // 2), got inp={inp}, oup={oup}")
        downsample = stride > 1

        # NOTE(review): the depthwise convs below keep nn.Conv2d's default bias=True
        # (torchvision uses bias=False; a bias directly before BatchNorm is redundant).
        # Left unchanged so state_dicts saved with this module still load.
        if downsample:
            # Path for the full input: depthwise 3x3 (strided) -> BN -> 1x1 -> BN -> ReLU.
            self.branch1 = nn.Sequential(
                nn.Conv2d(inp, inp, kernel_size=3, stride=stride, padding=1, groups=inp),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            # Not used when stride == 1: forward passes the first half through unchanged.
            self.branch1 = nn.Sequential()

        # 1x1 -> BN -> ReLU -> depthwise 3x3 (strided) -> BN -> 1x1 -> BN -> ReLU.
        # With stride == 1 it only receives the second channel half (branch_features channels).
        self.branch2 = nn.Sequential(
            nn.Conv2d(
                inp if downsample else branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_features, branch_features, kernel_size=3, stride=stride, padding=1, groups=branch_features),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the unit to a 4-D ``(N, C, H, W)`` tensor and return the shuffled result."""
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        # Mix information between the two halves for the next unit.
        return self.channel_shuffle(out, 2)

    def channel_shuffle(self, x, groups):
        """Interleave the channels of ``x`` across ``groups`` groups.

        ``(N, C, H, W)`` is viewed as ``(N, groups, C // groups, H, W)``, the two
        group axes are swapped, and the result is flattened back to ``(N, C, H, W)``.
        ``C`` must be divisible by ``groups``.
        """
        n, c, h, w = x.size()
        out = x.view(n, groups, c // groups, h, w).permute(0, 2, 1, 3, 4).contiguous().view(n, c, h, w)
        return out
（2）在 ultralytics/nn/tasks.py 中导入模块
from ultralytics.nn.modules.shufflenetv2 import *
在 tasks.py 的 parse_model() 中添加对应的解析分支：
elif m in [ShuffleNetV2, Conv_maxpool]:
c1, c2 = ch[f], args[0]
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
c2 = make_divisible(c2 * width, 8)
args = [c1, c2, *args[1:]]
（3）模型 yaml 配置文件
backbone: #3*640*640
# [from, repeats, module, args]
- [-1, 1, Conv_maxpool, [24]] # 0-P2/4 #24*160*160
  - [-1, 1, ShuffleNetV2, [116, 2]] # 1-P3/8  args: [输出通道数, 步距stride]
- [-1, 3, ShuffleNetV2, [116, 1]] # 2
- [-1, 1, ShuffleNetV2, [232, 2]] # 3-P4/16
- [-1, 7, ShuffleNetV2, [232, 1]] # 4
- [-1, 1, ShuffleNetV2, [464, 2]] # 5-P5/32
- [-1, 3, ShuffleNetV2, [464, 1]] # 6
- [-1, 1, SPPF, [1024, 5]] # 7
- [-1, 2, C2PSA, [1024]] # 8
# YOLO11n head
head:
- [-1, 1, nn.Upsample, [None, 2, 'nearest']] #9
- [[-1, 4], 1, Concat, [1]] # cat backbone P4 10
- [-1, 2, C3k2, [512, False]] # 11
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 2], 1, Concat, [1]] # cat backbone P3
- [-1, 2, C3k2, [256, False]] # 14 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 11], 1, Concat, [1]] # cat head P4
- [-1, 2, C3k2, [512, False]] # 17 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 8], 1, Concat, [1]] # cat head P5
- [-1, 2, C3k2, [1024, True]] # 20 (P5/32-large)
- [[14, 17, 20], 1, Detect, [nc]] # Detect(P3, P4, P5)