1. Principle
(1) Cross-entropy loss:
a. Binary form:
CE(p, y) = -y*log(p) - (1 - y)*log(1 - p)
b. General form (after simplification, where pt is the probability the model assigns to the true class):
CE(pt) = -log(pt)
With imbalanced classes, cross-entropy loss is dominated by the abundant, easy examples, so the model predicts rare classes poorly. Focal loss addresses this by dynamically re-weighting each sample's contribution to the loss.
(2) Focal loss
The weight applied to each sample can be understood as:
(1 - pt)^gamma
where gamma > 0 is a tunable focusing parameter, giving the loss:
FL(pt) = -(1 - pt)^gamma * log(pt)
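A quick numeric illustration of the modulating factor (a sketch added here, not from the original; pt values are made up for illustration):

```python
# Modulating factor (1 - pt) ** gamma with the common default gamma = 2.
gamma = 2
easy_pt = 0.9   # a well-classified (easy) example
hard_pt = 0.1   # a misclassified (hard) example

easy_w = (1 - easy_pt) ** gamma   # (0.1)^2 = 0.01 -> loss down-weighted ~100x
hard_w = (1 - hard_pt) ** gamma   # (0.9)^2 = 0.81 -> loss nearly unchanged
print(easy_w, hard_w)
```

So easy examples are strongly suppressed while hard examples keep almost their full cross-entropy loss, which is exactly how focal loss shifts training focus toward rare or difficult samples.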
2. Code demo
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, input_, target):
        # input_: [num, num_tags] raw logits
        # target: [num] class indices
        # logpt: [num, num_tags]
        logpt = F.log_softmax(input_, dim=1)
        # pt: [num, num_tags]
        pt = torch.exp(logpt)
        # Scale the log-probabilities by the modulating factor (1 - pt)^gamma.
        logpt = (1 - pt) ** self.gamma * logpt
        # nll_loss picks out the target-class entry for each sample.
        loss = F.nll_loss(logpt, target, self.weight,
                          ignore_index=self.ignore_index)
        return loss


loss = FocalLoss()
input_ = torch.randn([5, 10])
target = torch.ones([5]).long()
print('Focal loss:', loss(input_, target))
Output (the exact value varies, since the input is random):
Focal loss: tensor(2.1576)
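One way to sanity-check the implementation: with gamma = 0 the modulating factor is identically 1, so focal loss should reduce to plain cross-entropy. A minimal check, computing the same steps inline rather than through the class:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 10)
target = torch.ones(5, dtype=torch.long)

gamma = 0
logpt = F.log_softmax(logits, dim=1)
pt = torch.exp(logpt)
# With gamma = 0, (1 - pt) ** gamma == 1 everywhere.
focal = F.nll_loss((1 - pt) ** gamma * logpt, target)
ce = F.cross_entropy(logits, target)
print(focal.item(), ce.item())
```

The two values match because F.cross_entropy is exactly nll_loss applied to log_softmax; the gamma term is the only thing focal loss adds.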