Focal Loss for Dense Object Detection
通过对不同样本的loss进行加权,从而达到聚焦于学习困难样本的方法,该方法普适性很强。
Key words : Sample balance、Hard example、Focusing parameter
Subjects: Computer Vision and Pattern Recognition (cs.CV)
ICCV2017
作者:RBG和Kaiming
Agile Pioneer
交叉熵的计算形式如下:
FocalLoss定义如下:
One-stage的目标检测算中存在正负样本不均衡的情况,以及困难样本难以分类的情况。
F L ( p t ) = − α t ( 1 − p t ) γ l o g ( p t ) FL(p_t) = -\alpha_t(1-p_t)^{\gamma}log(p_t) FL(pt)=−αt(1−pt)γlog(pt)
核心思想
1. 解决正负样本不均衡问题 α ∈ ( 0 , 1 ) \alpha \in (0, 1) α∈(0,1)
-
1.1 α \alpha α 是调节由正负样本产生loss的参数,比如一个batch的8个样本中有7个是负样本,那么负样本loss的占比就很高,可以通过这个参数来对负样本产生的loss进行制约,相对提高正样本的loss
-
1.2 所以在使用的时候 α \alpha α 向量是一个长度和batch_size相同的向量,然后根据类别索引来确定正负样本,来生成对应的loss权重向量[ α \alpha α, …, 1 − α 1 - \alpha 1−α,…]
-
1.3 多类别的时候怎么办呢,可以考虑不同样本的占比作为权重
-
1.4 对于训练集合中样本比例较大的类别乘以 α \alpha α(<0.5),样本比例较小的类别1- α \alpha α,是正常的逻辑,但是在Focal loss 中是正样本乘以 α \alpha α(<0.5),本来正样本难以“匹敌”负样本,但经过下面介绍的 γ \gamma γ 的“操控”后,也许形势还逆转了,还要对正样本降权,已达到更好的效果
2. 解决困难样本学习问题 γ ∈ [ 0 , + i n f ) \gamma \in [0, +inf ) γ∈[0,+inf)
-
2.1 γ \gamma γ 称为 “focusing parameter”,目的是通过减少易分类样本的权重(而不是增加困难样本的权重,这样hard example的权重相对就提升了很多),从而使得模型在训练时更专注于难分类的样本
-
2.2 当一个样本被分错的时候,p是很小的,那么 γ = ( 1 − P ) \gamma=(1-P) γ=(1−P)接近1,损失不被影响
-
2.3 当P→1,因子 γ = ( 1 − P ) \gamma=(1-P) γ=(1−P)接近0,那么分的比较好的(well-classified)样本的权值就被调低了。
-
2.4 当 γ = 0 \gamma=0 γ=0的时候,focal loss就是传统的交叉熵损失,当 γ \gamma γ增加的时候,调制系数也会增加。 专注参数 γ \gamma γ平滑地调节了易分样本调低权值的比例。 γ \gamma γ增大能增强调制因子的影响,实验发现 γ \gamma γ取2最好。
-
2.5 这里的 p t p_t pt是预测的onehot结果中对应真实类别的概率。
直觉上来说,调制因子减少了易分样本的损失贡献,拓宽了样例接收到低损失的范围。当 γ \gamma γ一定的时候,比如等于2,一样easy example(pt=0.9)的loss要比标准的交叉熵loss小100+倍,当pt=0.968时,要小1000+倍,但是对于hard example(pt < 0.5),loss最多小了4倍。这样的话hard example的权重相对就提升了很多。这样就增加了那些误分类的重要性。
作者建议的最佳值:alpha取值为0.25, gamma=2
Focal loss
p t p_t pt - positive prob
F o c a l _ l o s s = { − α ( 1 − p t ) γ l o g ( p t ) y = 1 ( 1 − α ) ( p t ) γ l o g ( 1 − p t ) y = 0 Focal\_loss=\begin{cases} -\alpha (1 - p_t)^{\gamma}log(p_t) & y = 1 \\ (1-\alpha) (p_t)^{\gamma}log(1 - p_t) & y = 0 \\ \end{cases} Focal_loss={−α(1−pt)γlog(pt)(1−α)(pt)γlog(1−pt)y=1y=0
问下自己这些问题:
Q:Focal loss是如何计算的,例如p=0.9或p=0.5?
Q:Focal loss如何应用在多分类中, α \alpha α 和 γ \gamma γ 如何取值?
Q:Focal loss中 1 − p t 1-p_t 1−pt中的 p t p_t pt对应哪个类别的概率?
这是我基于pytorch写的二分类的LableSmooth结合FocalLoss的代码
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class FocalLossWithLabelSmooth(nn.Module):
def __init__(self,class_num, alpha=0.25, gamma=2, eps=0.06):
super(FocalLossWithLabelSmooth, self).__init__()
self.class_num = class_num
# for label smooth
self.eps = eps
# for focal loss
self.alpha = alpha
self.gamma = gamma
def forward(self, pred, label):
label = label.contiguous().view(-1)
one_hot_label = torch.zeros_like(pred)
one_hot_label = one_hot_label.scatter(1, label.view(-1, 1), 1)
one_hot_label = one_hot_label * (1 - self.eps) + (1 - one_hot_label) * self.eps / (self.class_num - 1)
# for label smooth
log_prob = F.log_softmax(pred, dim=1)
CEloss = (one_hot_label * log_prob).sum(dim=1)
#print(one_hot_label)
# for focal loss
P = F.softmax(pred, 1)
class_mask = pred.data.new(pred.size(0), pred.size(1)).fill_(0)
class_mask = Variable(class_mask)
ids = label.view(-1, 1)
class_mask.scatter_(1, ids.data, 1.)
probs = (P * class_mask).sum(1).view(-1, 1)
#print(probs)
# if multi-class you need to modify here
alpha = torch.empty(label.size()).fill_(1 - self.alpha)
# TODO: multi class
alpha[label == 1] = self.alpha
if pred.is_cuda and not alpha.is_cuda:
alpha = alpha.cuda()
batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * CEloss
loss = batch_loss.mean()
return loss
参考:
[1] https://zhuanlan.zhihu.com/p/49981234
[2] https://blog.csdn.net/LeeWanzhi/article/details/80069592
[3] https://blog.csdn.net/weixin_44638957/article/details/100733971