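Key point: the label holds one class index per location, laid out as B H W (the values range over the C classes), and is flattened to B*H*W.
The input (the predicted scores) must be permuted to B H W C and then reshaped into two dimensions, (B*H*W, C).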
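The layout change above can be checked with a small sketch. The sizes (B=2, C=3, H=4, W=5) are arbitrary illustrative values; the point is that after `permute(0, 2, 3, 1)` followed by `reshape(-1, C)`, row `(b*H + h)*W + w` of the prediction lines up with entry `(b, h, w)` of the flattened label.

```python
import torch

# Illustrative sizes only: batch B, classes C, feature map H x W
B, C, H, W = 2, 3, 4, 5
cls_prob = torch.randn(B, C, H, W)           # head output, channels-first
label_cls = torch.randint(0, C, (B, H, W))   # one class index per location

# Move the class channel last, then merge B, H, W into a single axis
flatten_cls_prob = cls_prob.permute(0, 2, 3, 1).reshape(-1, C)  # (B*H*W, C)
flatten_label_cls = label_cls.reshape(-1)                       # (B*H*W,)

# Row i of the prediction corresponds to label i: check one location
b, h, w = 1, 2, 3
row = (b * H + h) * W + w
assert torch.equal(flatten_cls_prob[row], cls_prob[b, :, h, w])
assert flatten_label_cls[row] == label_cls[b, h, w]
```

`reshape` is used rather than `view` because the permuted tensor is no longer contiguous, and `view` would raise an error on it. Calling `cls_prob.reshape(-1, C)` directly, without the permute, would also produce a (B*H*W, C) tensor, but each row would mix values from different locations instead of holding the C scores of one location.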
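Sigmoid focal loss does not include a background class: cls_num (the number of output channels) is the number of classes minus one, and a background location is simply negative for every channel.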
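A minimal sketch of a sigmoid focal loss without a background channel. It assumes, matching the `label_cls != 0` test in the code below, that label 0 is background and foreground class k uses channel k-1; background rows therefore get an all-zero one-hot target. The function name `sigmoid_focal_loss` and its defaults (`gamma=2.0`, `alpha=0.25`) are illustrative choices, not a specific library's API. If `self.loss_cls` is mmdetection's `FocalLoss` (v2 and later), note that it uses the opposite convention, treating label `num_classes` as background.

```python
import torch
import torch.nn.functional as F

def sigmoid_focal_loss(pred, target, gamma=2.0, alpha=0.25, avg_factor=None):
    """Sigmoid focal loss with no background channel (illustrative sketch).

    pred:   (N, K) logits, one channel per foreground class
    target: (N,) labels in [0, K]; 0 is ASSUMED to be background,
            foreground class k maps to channel k - 1
    """
    num_fg = pred.size(1)
    # One-hot over K + 1 labels, then drop the background column,
    # so background rows become all zeros (negative for every class)
    one_hot = F.one_hot(target, num_classes=num_fg + 1)[:, 1:].type_as(pred)
    p = pred.sigmoid()
    # p_t: probability of the correct binary outcome for each channel
    p_t = p * one_hot + (1 - p) * (1 - one_hot)
    alpha_t = alpha * one_hot + (1 - alpha) * (1 - one_hot)
    # BCE equals -log(p_t); the focal term (1 - p_t)^gamma down-weights easy cases
    bce = F.binary_cross_entropy_with_logits(pred, one_hot, reduction="none")
    loss = alpha_t * (1 - p_t).pow(gamma) * bce
    if avg_factor is not None:
        return loss.sum() / avg_factor
    return loss.mean()

pred = torch.zeros(3, 2)          # K = 2 foreground channels, no background channel
target = torch.tensor([0, 1, 2])  # background, class 1, class 2
loss = sigmoid_focal_loss(pred, target)
```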
Softmax focal loss includes background as one of the classes, so it keeps a background channel.
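For contrast, a sketch of a softmax (multi-class) focal loss in which background is an ordinary class. It assumes channel 0 is background and omits any per-class alpha weighting for brevity; `softmax_focal_loss` is a hypothetical name. With `gamma=0` it reduces to plain cross entropy, which makes a convenient sanity check.

```python
import torch
import torch.nn.functional as F

def softmax_focal_loss(pred, target, gamma=2.0, avg_factor=None):
    """Softmax focal loss with background as a real class (illustrative sketch).

    pred:   (N, K + 1) logits; channel 0 is ASSUMED to be background
    target: (N,) labels in [0, K]
    """
    log_p = F.log_softmax(pred, dim=1)
    # Log-probability of the true class for each row
    log_p_t = log_p.gather(1, target.unsqueeze(1)).squeeze(1)
    p_t = log_p_t.exp()
    # Cross entropy scaled by the focal term (1 - p_t)^gamma
    loss = -(1 - p_t).pow(gamma) * log_p_t
    if avg_factor is not None:
        return loss.sum() / avg_factor
    return loss.mean()

torch.manual_seed(0)
pred = torch.randn(4, 3)              # 2 foreground classes + 1 background channel
target = torch.tensor([0, 0, 1, 2])
loss = softmax_focal_loss(pred, target)
ce_equiv = softmax_focal_loss(pred, target, gamma=0.0)  # no focusing term
```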
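import torch

def loss(self, label_cls, cls_prob):
    # cls_prob: (B, C, H, W) class scores; label_cls: (B, H, W) class indices, 0 = background
    num_imgs = cls_prob.size(0)
    # Number of foreground (positive) locations
    num_pos = torch.sum(label_cls != 0).item()
    # (B, C, H, W) -> (B, H, W, C) -> (B*H*W, C)
    flatten_cls_prob = cls_prob.permute(0, 2, 3, 1).reshape(-1, self.cls_num)
    # (B, H, W) -> (B*H*W,)
    flatten_label_cls = label_cls.reshape(-1)
    # Adding num_imgs keeps avg_factor > 0 when a batch has no positives
    loss_cls = self.loss_cls(
        pred=flatten_cls_prob,
        target=flatten_label_cls,
        avg_factor=num_pos + num_imgs)
    return loss_cls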
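To see how `avg_factor = num_pos + num_imgs` behaves, here is a small harness. `DemoHead` is a hypothetical class, and its `loss_cls` is a stand-in (summed cross entropy divided by `avg_factor`, softmax convention with background in channel 0), not a real focal loss. The `loss` method body repeats the code above. A batch with one positive location gives an `avg_factor` of 3, and an all-background batch still gives 2 instead of 0, so the loss stays finite.

```python
import torch
import torch.nn.functional as F

class DemoHead:
    """Hypothetical head used only to exercise the loss() method."""

    def __init__(self, cls_num):
        self.cls_num = cls_num

    def loss_cls(self, pred, target, avg_factor):
        # Stand-in for a focal loss: summed cross entropy / avg_factor
        self.last_avg_factor = avg_factor
        return F.cross_entropy(pred, target, reduction="sum") / avg_factor

    def loss(self, label_cls, cls_prob):
        num_imgs = cls_prob.size(0)
        num_pos = torch.sum(label_cls != 0).item()
        flatten_cls_prob = cls_prob.permute(0, 2, 3, 1).reshape(-1, self.cls_num)
        flatten_label_cls = label_cls.reshape(-1)
        return self.loss_cls(pred=flatten_cls_prob, target=flatten_label_cls,
                             avg_factor=num_pos + num_imgs)

head = DemoHead(cls_num=3)
cls_prob = torch.zeros(2, 3, 4, 4)                  # B=2, C=3, H=W=4
label_cls = torch.zeros(2, 4, 4, dtype=torch.long)  # all background...
label_cls[0, 1, 1] = 2                              # ...except one location

loss_one_pos = head.loss(label_cls, cls_prob)
avg_one_pos = head.last_avg_factor       # 1 positive + 2 images

loss_empty = head.loss(torch.zeros_like(label_cls), cls_prob)
avg_no_pos = head.last_avg_factor        # 0 positives + 2 images, still > 0
```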