原理
从17年被RetinaNet提出,Focal Loss 一直备受好评。由于其着重关注分类较差的样本的思想,Focal loss以简单的形式,一定程度解决了样本的难例挖掘,样本不均衡的问题。
普通的Cross Entropy
C
E
(
p
t
)
=
−
a
t
l
o
g
(
p
t
)
CE(p_t) = -a_t log(p_t)
CE(pt)=−atlog(pt)
a
t
a_t
at是平衡因子。
Focal Loss
F
L
(
p
t
)
=
−
(
1
−
p
t
)
r
l
o
g
(
p
t
)
FL(p_t) = -(1-p_t)^rlog(p_t)
FL(pt)=−(1−pt)rlog(pt)
在log前面加上
(
1
−
p
t
)
(1-p_t)
(1−pt)是focal loss的核心。假设
r
r
r设置为2。当
p
t
p_t
pt为0.9,说明网络已经将这个样本分的很好了,那么
(
1
−
p
t
)
2
(1-p_t)^2
(1−pt)2 为0.01,呈指数级降低了这个样本对损失函数的贡献。当
p
t
p_t
pt为0.1,说明网络对样本还不具有很好地分类能力,那么
(
1
−
p
t
)
2
(1-p_t)^2
(1−pt)2为0.81。 简单言之,focal加大了对难分类样本的关注。
代码
来自知乎大佬
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class FocalLoss(nn.Module):
def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
super(FocalLoss, self).__init__()
if alpha is None: # alpha 是平衡因子
self.alpha = Variable(torch.ones(class_num, 1))
else:
if isinstance(alpha, Variable):
self.alpha = alpha
else:
self.alpha = Variable(alpha)
self.gamma = gamma # 指数
self.class_num = class_num # 类别数目
self.size_average = size_average # 返回的loss是否需要mean一下
def forward(self, inputs, targets):
# target : N, 1, H, W
inputs = inputs.permute(0, 2, 3, 1)
targets = targets.permute(0, 2, 3, 1)
num, h, w, C = inputs.size()
N = num * h * w
inputs = inputs.reshape(N, -1) # N, C
targets = targets.reshape(N, -1) # 待转换为one hot label
P = F.softmax(inputs, dim=1) # 先求p_t
class_mask = inputs.data.new(N, C).fill_(0)
class_mask = Variable(class_mask)
ids = targets.view(-1, 1)
class_mask.scatter_(1, ids.data, 1.) # 得到label的one_hot编码
if inputs.is_cuda and not self.alpha.is_cuda:
self.alpha = self.alpha.cuda() # 如果是多GPU训练 这里的cuda要指定搬运到指定GPU上 分布式多进程训练除外
alpha = self.alpha[ids.data.view(-1)]
# y*p_t 如果这里不用*, 还可以用gather提取出正确分到的类别概率。
# 之所以能用sum,是因为class_mask已经把预测错误的概率清零了。
probs = (P * class_mask).sum(1).view(-1, 1)
# y*log(p_t)
log_p = probs.log()
# -a * (1-p_t)^2 * log(p_t)
batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p
if self.size_average:
loss = batch_loss.mean()
else:
loss = batch_loss.sum()
return loss
在代码我写了清晰的注释。该Focal loss可适用于大于2类的分类任务。