# FC heads are better suited to classification; conv heads are better suited to localization.
# (fc-head更适合分类任务，conv-head更适合定位任务)
class DecoupledHead(nn.Module):
    """Decoupled prediction head for a single feature level.

    Adapted from code by 老程 on CSDN (https://blog.csdn.net/weixin_44119362).

    The input feature map is projected to ``int(256 * width)`` channels by a
    1x1 ``Conv`` and then split into two branches:

    * classification: 3x3 Conv -> 3x3 Conv -> 1x1 conv predicting ``nc``
      class scores per anchor;
    * regression: 3x3 Conv -> 3x3 Conv, whose output is shared by a 1x1 conv
      predicting 4 box values per anchor and a 1x1 conv predicting 1
      objectness value per anchor.

    ``Conv`` is defined elsewhere in the project; it is assumed to be the
    YOLOv5-style ``Conv(c1, c2, k, s, p)`` block — TODO confirm.

    Args:
        ch: Number of channels of the input feature map.
        nc: Number of classes.
        width: Width multiplier for the 256-channel hidden size. The product
            is truncated to an int because conv channel counts must be ints.
        anchors: Per-level anchor specs. ``len(anchors)`` is stored as the
            number of detection layers and ``len(anchors[0]) // 2`` as the
            number of anchors per location (presumably flat (w, h) pairs).

    Raises:
        ValueError: If ``anchors`` is empty.
    """

    def __init__(self, ch=256, nc=80, width=1.0, anchors=()):
        super().__init__()
        if len(anchors) == 0:
            # The original indexed anchors[0] unconditionally and died with
            # an opaque IndexError on the default value.
            raise ValueError("DecoupledHead requires a non-empty 'anchors' sequence")
        self.nc = nc  # number of classes
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors per location

        # Channel counts must be ints: 256 * 1.0 == 256.0 would make the
        # underlying nn.Conv2d fail to allocate its weight tensor.
        c_ = int(256 * width)

        self.merge = Conv(ch, c_, 1, 1)
        # Classification branch.
        self.cls_convs1 = Conv(c_, c_, 3, 1, 1)
        self.cls_convs2 = Conv(c_, c_, 3, 1, 1)
        # Regression branch (shared by box and objectness predictions).
        self.reg_convs1 = Conv(c_, c_, 3, 1, 1)
        self.reg_convs2 = Conv(c_, c_, 3, 1, 1)
        # 1x1 prediction layers.
        self.cls_preds = nn.Conv2d(c_, self.nc * self.na, 1)
        self.reg_preds = nn.Conv2d(c_, 4 * self.na, 1)
        self.obj_preds = nn.Conv2d(c_, 1 * self.na, 1)

    def forward(self, x):
        """Run both branches and merge their predictions into one map.

        Args:
            x: Input feature map; assumed (bs, ch, ny, nx) — TODO confirm
                against the caller.

        Returns:
            Tensor of shape (bs, na * (5 + nc), ny, nx). Channels are grouped
            per anchor: for each anchor, its 4 box values, then its 1
            objectness value, then its nc class scores.
        """
        x = self.merge(x)

        # Classification: 3x3 conv -> 3x3 conv -> 1x1 pred.
        cls_out = self.cls_preds(self.cls_convs2(self.cls_convs1(x)))

        # Regression: 3x3 conv -> 3x3 conv (shared) -> 1x1 box pred and
        # 1x1 objectness pred.
        reg_feat = self.reg_convs2(self.reg_convs1(x))
        box_out = self.reg_preds(reg_feat)
        obj_out = self.obj_preds(reg_feat)

        # Interleave per anchor rather than concatenating whole branches.
        # With na > 1, a flat cat([box, obj, cls], 1) puts every anchor's box
        # channels first. A consumer that views the map as
        # (bs, na, 5 + nc, ny, nx) would then read most anchors' box and
        # objectness values from the wrong branch. NOTE(review): this layout
        # is assumed to match the downstream Detect layer — confirm.
        # For na == 1 both layouts are identical.
        bs, _, ny, nx = box_out.shape
        out = torch.cat(
            (
                box_out.reshape(bs, self.na, 4, ny, nx),
                obj_out.reshape(bs, self.na, 1, ny, nx),
                cls_out.reshape(bs, self.na, self.nc, ny, nx),
            ),
            dim=2,
        )
        return out.reshape(bs, self.na * (5 + self.nc), ny, nx)
# Source: adapted from a CSDN blog post (摘自csdn).