YOLOX-s修改特征融合部分(替换YOLOPAFPN)
参考bubbliiiing大佬的代码
def conv2d(filter_in, filter_out, kernel_size, stride=1):
    """Build a Conv2d -> BatchNorm2d -> LeakyReLU(0.1) block.

    Args:
        filter_in: Number of input channels of the convolution.
        filter_out: Number of output channels of the convolution; also the
            number of features of the batch-norm layer.
        kernel_size: Convolution kernel size. A falsy value (e.g. ``0``)
            results in zero padding; otherwise padding is
            ``(kernel_size - 1) // 2``.
        stride: Convolution stride, 1 by default.

    Returns:
        An ``nn.Sequential`` whose children are named ``"conv"``, ``"bn"``
        and ``"relu"``, in that order. The convolution has no bias term.
    """
    if kernel_size:
        padding = (kernel_size - 1) // 2
    else:
        padding = 0

    # Layers are inserted in execution order; the keys become the names of
    # the sub-modules (and therefore the state_dict prefixes).
    layers = OrderedDict()
    layers["conv"] = nn.Conv2d(
        filter_in,
        filter_out,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    layers["bn"] = nn.BatchNorm2d(filter_out)
    layers["relu"] = nn.LeakyReLU(0.1)
    return nn.Sequential(layers)
class Upsample(nn.Module):
    """1x1 conv block followed by 2x nearest-neighbour upsampling.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the 1x1 conv block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Change the channel count first, then double height and width.
        channel_mixer = conv2d(in_channels, out_channels, 1)
        enlarge = nn.Upsample(scale_factor=2, mode='nearest')
        self.upsample = nn.Sequential(channel_mixer, enlarge)

    def forward(self, x):
        """Apply the conv block and the 2x upsampling to ``x``."""
        return self.upsample(x)
class YOLOPAFPN(nn.Module):
    """Weighted bidirectional feature-fusion neck on a CSPDarknet backbone.

    Three backbone feature maps (selected by ``in_features``) are taken as
    P3, P4 and P5. Two stride-2 conv blocks derive P6 and P7 from P5. The
    five levels are then fused top-down (P7 -> P3) and bottom-up (P3 -> P7).
    Each fusion is a weighted sum whose learnable weights are passed through
    ReLU and divided by their sum plus ``epsilon`` (in the style of BiFPN's
    fast normalized fusion), followed by ``Swish`` and a 3x3 conv block.

    Args:
        depth: Forwarded to ``CSPDarknet``.
        width: Forwarded to ``CSPDarknet``. It does NOT scale
            ``in_channels``.
        in_features: Keys of the three backbone outputs used as P3, P4, P5
            (in that order).
        in_channels: Channel counts for P3..P7 at indices 0..4; further
            entries are not read. NOTE(review): entries 0..2 must equal the
            backbone's real output channels for the chosen ``width``; the
            defaults presumably correspond to width=0.5 (YOLOX-s) - confirm.
        depthwise: Forwarded to ``CSPDarknet`` only.
        act: Forwarded to ``CSPDarknet`` only; the fusion conv blocks are
            built by ``conv2d`` and do not take this argument.
        epsilon: Added to the weight sum before division so the
            normalisation never divides by zero.

    NOTE(review): the top-down path doubles the spatial size with
    ``Upsample`` and adds the result to the next finer level, while P6/P7
    are produced by stride-2 3x3 convs. This presumably requires the P5
    height/width to be divisible by 4 - confirm against supported input
    sizes.
    """

    def __init__(self, depth=1.0, width=1.0, in_features=("dark3", "dark4", "dark5"),
                 in_channels=(128, 256, 512, 1024, 2048, 4096),
                 depthwise=False, act="silu", epsilon=1e-4):
        super().__init__()
        self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
        self.in_features = in_features
        self.epsilon = epsilon
        self.swish = Swish()

        # Extra pyramid levels: halve P5 spatially to get P6
        # (in_channels[2] -> in_channels[3]) ...
        self.p5_to_p6 = conv2d(in_channels[2], in_channels[3], kernel_size=3, stride=2)
        # ... and halve P6 to get P7 (in_channels[3] -> in_channels[4]).
        self.p6_to_p7 = conv2d(in_channels[3], in_channels[4], kernel_size=3, stride=2)

        # Learnable fusion weights. "*_w1" belong to the top-down pass (two
        # inputs each); "*_w2" belong to the bottom-up pass (three inputs,
        # except P7 which has two). Each has its own ReLU module so that the
        # module tree stays identical to earlier versions of this class.
        self.p6_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p6_w1_relu = nn.ReLU()
        self.p5_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p5_w1_relu = nn.ReLU()
        self.p4_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p4_w1_relu = nn.ReLU()
        self.p3_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p3_w1_relu = nn.ReLU()

        self.p4_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
        self.p4_w2_relu = nn.ReLU()
        self.p5_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
        self.p5_w2_relu = nn.ReLU()
        self.p6_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
        self.p6_w2_relu = nn.ReLU()
        self.p7_w2 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p7_w2_relu = nn.ReLU()

        # Top-down path: bring the coarser level to the finer level's
        # channel count and resolution, then refine the fused map.
        self.p6_upsample = Upsample(in_channels[4], in_channels[3])
        self.conv6_up = conv2d(in_channels[3], in_channels[3], kernel_size=3)
        self.p5_upsample = Upsample(in_channels[3], in_channels[2])
        self.conv5_up = conv2d(in_channels[2], in_channels[2], kernel_size=3)
        self.p4_upsample = Upsample(in_channels[2], in_channels[1])
        self.conv4_up = conv2d(in_channels[1], in_channels[1], kernel_size=3)
        self.p3_upsample = Upsample(in_channels[1], in_channels[0])
        self.conv3_up = conv2d(in_channels[0], in_channels[0], kernel_size=3)

        # Bottom-up path: stride-2 conv from the finer level, then refine
        # the fused map.
        self.p4_downsample = conv2d(in_channels[0], in_channels[1], kernel_size=3, stride=2)
        self.conv4_down = conv2d(in_channels[1], in_channels[1], kernel_size=3)
        self.p5_downsample = conv2d(in_channels[1], in_channels[2], kernel_size=3, stride=2)
        self.conv5_down = conv2d(in_channels[2], in_channels[2], kernel_size=3)
        self.p6_downsample = conv2d(in_channels[2], in_channels[3], kernel_size=3, stride=2)
        self.conv6_down = conv2d(in_channels[3], in_channels[3], kernel_size=3)
        self.p7_downsample = conv2d(in_channels[3], in_channels[4], kernel_size=3, stride=2)
        self.conv7_down = conv2d(in_channels[4], in_channels[4], kernel_size=3)

    def _fusion_weights(self, raw_weights, relu):
        """Return ``relu(raw_weights)`` divided by (their sum + epsilon)."""
        positive = relu(raw_weights)
        return positive / (torch.sum(positive, dim=0) + self.epsilon)

    def forward(self, input):
        """Run the backbone and fuse its features into five pyramid levels.

        Args:
            input: Tensor passed unchanged to the backbone.

        Returns:
            Tuple ``(p3_out, p4_out, p5_out, p6_out, p7_out)``.
        """
        # Call the module itself rather than ``.forward`` so that hooks
        # registered on the backbone are executed.
        out_features = self.backbone(input)
        p3_in, p4_in, p5_in = [out_features[f] for f in self.in_features]

        p6_in = self.p5_to_p6(p5_in)
        p7_in = self.p6_to_p7(p6_in)

        # ---- top-down pass ------------------------------------------------
        # P6: balance p6_in against the upsampled p7_in.
        weight = self._fusion_weights(self.p6_w1, self.p6_w1_relu)
        p6_td = self.conv6_up(
            self.swish(weight[0] * p6_in + weight[1] * self.p6_upsample(p7_in)))

        # P5: balance p5_in against the upsampled p6_td.
        weight = self._fusion_weights(self.p5_w1, self.p5_w1_relu)
        p5_td = self.conv5_up(
            self.swish(weight[0] * p5_in + weight[1] * self.p5_upsample(p6_td)))

        # P4: balance p4_in against the upsampled p5_td.
        weight = self._fusion_weights(self.p4_w1, self.p4_w1_relu)
        p4_td = self.conv4_up(
            self.swish(weight[0] * p4_in + weight[1] * self.p4_upsample(p5_td)))

        # P3: balance p3_in against the upsampled p4_td. The finest level
        # has no bottom-up input, so this is already the P3 output.
        weight = self._fusion_weights(self.p3_w1, self.p3_w1_relu)
        p3_out = self.conv3_up(
            self.swish(weight[0] * p3_in + weight[1] * self.p3_upsample(p4_td)))

        # ---- bottom-up pass -----------------------------------------------
        # P4: balance p4_in, p4_td and the downsampled p3_out.
        weight = self._fusion_weights(self.p4_w2, self.p4_w2_relu)
        p4_out = self.conv4_down(
            self.swish(weight[0] * p4_in + weight[1] * p4_td
                       + weight[2] * self.p4_downsample(p3_out)))

        # P5: balance p5_in, p5_td and the downsampled p4_out.
        weight = self._fusion_weights(self.p5_w2, self.p5_w2_relu)
        p5_out = self.conv5_down(
            self.swish(weight[0] * p5_in + weight[1] * p5_td
                       + weight[2] * self.p5_downsample(p4_out)))

        # P6: balance p6_in, p6_td and the downsampled p5_out.
        weight = self._fusion_weights(self.p6_w2, self.p6_w2_relu)
        p6_out = self.conv6_down(
            self.swish(weight[0] * p6_in + weight[1] * p6_td
                       + weight[2] * self.p6_downsample(p5_out)))

        # P7: the coarsest level has no top-down intermediate, so only
        # p7_in and the downsampled p6_out are fused.
        weight = self._fusion_weights(self.p7_w2, self.p7_w2_relu)
        p7_out = self.conv7_down(
            self.swish(weight[0] * p7_in + weight[1] * self.p7_downsample(p6_out)))

        return p3_out, p4_out, p5_out, p6_out, p7_out