RetinaNet + Focal Loss

A major reason one-stage detectors are less accurate is the imbalance between positive and negative samples. Take YOLO as an example: each grid cell makes 5 predictions, so the positive/negative ratio, already lopsided to begin with, is effectively amplified another 5x.

The paper proposes a new classification loss, focal loss, which down-weights easily classified samples and focuses training on the samples that are hard to distinguish, keeping the positive/negative ratio from overwhelming the loss. In other words, focal loss addresses both the positive/negative imbalance and the easy/hard-sample imbalance.

The loss is FL(p_t) = -α_t (1 - p_t)^γ log(p_t), where p_t is the predicted probability of the true class (p for positives, 1 - p for negatives) and α_t = α for positives, 1 - α for negatives. Here α controls the positive/negative imbalance and γ controls the easy/hard imbalance; the paper uses α = 0.25 and γ = 2. With α = 0.25 the positive term is weighted by 0.25 and the negative term by 0.75 (the two α weights differ by a factor of 3), while the modulating factor (1 - p_t)^γ raises the relative contribution of hard samples and shrinks that of easy samples. The net effect is that hard samples and positives contribute relatively more to the total loss, and the flood of easy negatives contributes relatively less.
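
To make the weighting concrete, here is a minimal sketch (my own illustration, not from any reference implementation) that evaluates the focal weight α_t * (1 - p_t)^γ for an easy negative, a hard negative, and a positive prediction with α = 0.25, γ = 2:

import math

alpha, gamma = 0.25, 2.0

def focal_weight(p, is_positive):
    # p is the predicted foreground probability for this anchor
    p_t = p if is_positive else 1.0 - p             # probability of the true class
    alpha_t = alpha if is_positive else 1.0 - alpha
    return alpha_t * (1.0 - p_t) ** gamma

print(focal_weight(0.01, is_positive=False))   # easy negative -> 0.75 * 0.01^2 = 7.5e-05
print(focal_weight(0.60, is_positive=False))   # hard negative -> 0.75 * 0.60^2 = 0.27
print(focal_weight(0.60, is_positive=True))    # positive      -> 0.25 * 0.40^2 = 0.04

The easy negative's weight is three to four orders of magnitude smaller than the hard negative's, which is what keeps the huge number of easy negatives from dominating training.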

The model uses an FPN with pyramid levels P3 through P7; the extra P7 level helps with detecting large objects.

On FPN levels P3-P7 the base anchor sizes range from 32x32 up to 512x512, with aspect ratios {1:2, 1:1, 2:1}; combined with three scales per level (2^0, 2^(1/3), 2^(2/3)) this gives 9 anchors per level, and across levels the anchors cover object sizes from 32 to 813 pixels. Each anchor is assigned a K-dimensional one-hot classification target (K is the number of classes) and a 4-dimensional box-regression target.
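
As a sanity check on these numbers, the following sketch (my own illustration; the three per-level scales 2^0, 2^(1/3), 2^(2/3) are the setup that yields the 9 anchors mentioned above) enumerates the anchors and confirms that the largest one is roughly 813 pixels on a side:

import itertools

base_sizes = {'P3': 32, 'P4': 64, 'P5': 128, 'P6': 256, 'P7': 512}
scales = [2 ** 0, 2 ** (1 / 3), 2 ** (2 / 3)]   # 3 scales per level
ratios = [0.5, 1.0, 2.0]                        # aspect ratios 1:2, 1:1, 2:1

for level, base in base_sizes.items():
    anchors = []
    for scale, ratio in itertools.product(scales, ratios):
        area = (base * scale) ** 2          # anchors at one scale share the same area
        w = (area / ratio) ** 0.5
        h = w * ratio
        anchors.append((round(w), round(h)))
    print(level, f'{len(anchors)} anchors:', anchors)

print('largest anchor side:', round(512 * 2 ** (2 / 3)))   # ~813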

For each of the A anchors at every spatial position, the classification subnet predicts the probability that each of the K classes is present. For every FPN level, this subnet is a small FCN: four 3x3 conv layers with 256 channels, followed by a final 3x3 conv with K*A output channels, so each anchor ends up with a K-dimensional vector of per-class scores; since the target is one-hot, the highest-scoring class is taken as 1 and the remaining K-1 entries as 0. A traditional RPN uses a single 1x1 conv with 18 outputs for its classification branch, whereas RetinaNet uses a deeper head of five conv layers in total, and the experiments show that this deeper head improves results.

In parallel with the classification subnet, each FPN level also feeds a box-regression subnet. It is likewise an FCN and predicts the offset between each anchor and its matched ground-truth box: four 256-channel conv layers followed by a final layer with 4A outputs, i.e. one (x, y, w, h) offset vector per anchor. Note that this box regression is class-agnostic. The two subnets share the same structure but do not share parameters.
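
Below is a minimal PyTorch sketch of these two heads (my own illustration, not the official implementation); the 256-channel 3x3 convs, the K*A classification outputs and the class-agnostic 4*A regression outputs follow the description above, while the prior-probability bias initialisation comes from the paper:

import math
import torch
import torch.nn as nn

class RetinaHead(nn.Module):
    """Classification and box-regression subnets applied to every FPN level."""
    def __init__(self, in_channels=256, num_anchors=9, num_classes=80, prior=0.01):
        super().__init__()

        def subnet(out_channels):
            layers = []
            for _ in range(4):   # 4 x (3x3 conv, 256 channels) + ReLU
                layers += [nn.Conv2d(in_channels, in_channels, 3, padding=1), nn.ReLU()]
            layers.append(nn.Conv2d(in_channels, out_channels, 3, padding=1))
            return nn.Sequential(*layers)

        self.cls_subnet = subnet(num_anchors * num_classes)   # K*A outputs per position
        self.box_subnet = subnet(num_anchors * 4)             # 4*A outputs, class-agnostic
        # bias init so every anchor starts with ~prior foreground probability (paper's trick)
        nn.init.constant_(self.cls_subnet[-1].bias, -math.log((1 - prior) / prior))

    def forward(self, fpn_features):
        # the same heads (not shared between cls and box) run on every pyramid level
        cls_outs = [self.cls_subnet(f) for f in fpn_features]   # each: (B, K*A, H, W)
        box_outs = [self.box_subnet(f) for f in fpn_features]   # each: (B, 4*A, H, W)
        return cls_outs, box_outs

# usage: feats = [torch.rand(1, 256, s, s) for s in (64, 32, 16, 8, 4)]   # P3..P7
# cls_outs, box_outs = RetinaHead()(feats)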

Code:

Three equivalent implementations of focal loss computed over both positive and negative samples (they are algebraically the same, so all three print the same mean loss):


import torch
import torch.nn.functional as F

def focal_loss_one(alpha, beta, cls_preds, gts):
    # Element-wise form: alpha_t * (1 - p_t)^beta * BCE, summed over all
    # anchors/classes and normalized by the number of positive targets.
    print('====== implementation 1 =======')
    num_pos = gts.sum()
    print('==num_pos:', num_pos)
    alpha_tensor = torch.ones_like(cls_preds) * alpha
    alpha_tensor = torch.where(torch.eq(gts, 1.), alpha_tensor, 1. - alpha_tensor)
    print('===alpha_tensor===', alpha_tensor)
    preds = torch.where(torch.eq(gts, 1.), cls_preds, 1. - cls_preds)
    print('===1. - preds===', 1. - preds)
    focal_weight = alpha_tensor * torch.pow((1. - preds), beta)
    print('==focal_weight:', focal_weight)
    batch_bce_loss = -(gts * torch.log(cls_preds) + (1. - gts) * torch.log(1. - cls_preds))
    batch_focal_loss = focal_weight * batch_bce_loss
    print('==batch_focal_loss:', batch_focal_loss)
    batch_focal_loss = batch_focal_loss.sum()
    print('== batch_focal_loss:', batch_focal_loss)
    print('==batch_focal_loss.item():', batch_focal_loss.item())
    if num_pos != 0:
        mean_batch_focal_loss = batch_focal_loss / num_pos
    else:
        mean_batch_focal_loss = batch_focal_loss
    print('==mean_batch_focal_loss:', mean_batch_focal_loss)


def focal_loss_two(alpha, beta, cls_preds, gts):
    # Same loss, written as explicit positive and negative terms.
    print('====== implementation 2 =======')
    pos_inds = (gts == 1.0).float()
    print('==pos_inds:', pos_inds)
    neg_inds = (gts != 1.0).float()
    print('===neg_inds:', neg_inds)
    pos_loss = -pos_inds * alpha * (1.0 - cls_preds) ** beta * torch.log(cls_preds)
    neg_loss = -neg_inds * (1 - alpha) * ((cls_preds) ** beta) * torch.log(1.0 - cls_preds)
    num_pos = pos_inds.float().sum()
    print('==num_pos:', num_pos)
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()
    if num_pos == 0:
        mean_batch_focal_loss = neg_loss
    else:
        mean_batch_focal_loss = (pos_loss + neg_loss) / num_pos
    print('==mean_batch_focal_loss:', mean_batch_focal_loss)

def focal_loss_three(alpha, beta, cls_preds, gts):
    # Same loss, using F.binary_cross_entropy as the base term
    # (cls_preds are assumed to already be sigmoid-activated probabilities).
    print('====== implementation 3 =======')
    num_pos = gts.sum()
    pred_sigmoid = cls_preds
    target = gts.type_as(pred_sigmoid)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(beta)
    batch_focal_loss = F.binary_cross_entropy(
        pred_sigmoid, target, reduction='none') * focal_weight
    batch_focal_loss = batch_focal_loss.sum()
    if num_pos != 0:
        mean_batch_focal_loss = batch_focal_loss / num_pos
    else:
        mean_batch_focal_loss = batch_focal_loss
    print('==mean_batch_focal_loss:', mean_batch_focal_loss)

# toy example: a batch of 2 samples over 3 classes
bs = 2
num_class = 3
alpha = 0.25
beta = 2
# predicted probabilities (assumed post-sigmoid), shape (B, num_class)
cls_preds = torch.rand([bs, num_class], dtype=torch.float)
print('==cls_preds:', cls_preds)
# class indices converted to one-hot targets, shape (B, num_class)
gts = torch.tensor([0, 2])
gts = F.one_hot(gts, num_classes=num_class).type_as(cls_preds)
print('===gts===', gts)
focal_loss_one(alpha, beta, cls_preds, gts)
focal_loss_two(alpha, beta, cls_preds, gts)
focal_loss_three(alpha, beta, cls_preds, gts)

Focal loss computed only on the target (positive) class, i.e. the softmax multi-class form:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):
        """
        This criterion is a implemenation of Focal Loss, which is proposed in
        Focal Loss for Dense Object Detection.

            Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

        The losses are averaged across observations for each minibatch.

        Args:
            alpha(1D Tensor, Variable) : the scalar factor for this criterion
            gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
                                   putting more focus on hard, misclassified examples
            size_average(bool): By default, the losses are averaged over observations for each minibatch.
                                However, if the field size_average is set to False, the losses are
                                instead summed for each minibatch.


    """
    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        else:
            if isinstance(alpha, Variable):
                self.alpha = alpha
            else:
                self.alpha = Variable(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs, dim=-1)
        print('===P:', P)
        # .data gets the underlying tensor of a Variable
        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask = Variable(class_mask)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)  # scatter target indices into a one-hot mask
        print('==class_mask:', class_mask)
        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.data.view(-1)]
        print('==alpha:', alpha)
        probs = (P*class_mask).sum(1).view(-1, 1)
        print('==probs:', probs)
        log_p = probs.log()

        batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss

def debug_focal():
    import numpy as np
    # only the target-class probability enters the loss; easy samples are down-weighted
    loss = FocalLoss(class_num=8)#, alpha=torch.tensor([0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25]).reshape(-1, 1))
    inputs = torch.rand(2, 8)
    print('==inputs:', inputs)
    # print('==inputs.data:', inputs.data)
    # targets = torch.from_numpy(np.array([[1,0,0,0,0,0,0,0],
    #                                      [0,1,0,0,0,0,0,0]]))
    targets = torch.from_numpy(np.array([0, 1]))
    cost = loss(inputs, targets)
    print('===cost===:', cost)

if __name__ == '__main__':
    debug_focal()
