一、entroy_loss
entroy_loss已经讨论了很多了,具体可以看我的这篇博客
二、focal_loss
说到focal_loss,一定要看的是这张图:
从上图可以看到,focal_loss只比entroy_loss多了一个权重
α
(
1
−
p
t
)
γ
\alpha(1-p_t)^{\gamma}
α(1−pt)γ 当
p
t
p_t
pt越大时,赋予的权重越小,focal_loss能够降低简单样本的loss,让网络更偏重与比较难的样本。
下面是focal_loss的代码
class FocalLoss(nn.Module):
def __init__(self, gamma=0, eps=1e-7):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.eps = eps
self.ce = torch.nn.CrossEntropyLoss()
def forward(self, input, target):
logp = self.ce(input, target)
p = torch.exp(-logp)
loss = (1 - p) ** self.gamma * logp
return loss.mean()