-
原理:
-
实现代码,可直接使用
logits
和label
作为输入参数import torch import torch.nn.functional as F import torch.nn as nn class Focal_Loss(nn.Module): def __init__(self, weight, gamma=2): super(Focal_Loss,self).__init__() self.gamma = gamma self.weight = weight # 是tensor数据格式的列表 def forward(self, preds, labels): """ preds:logist输出值 labels:标签 """ preds = F.softmax(preds,dim=1) print(preds) eps = 1e-7 target = self.one_hot(preds.size(1), labels) print(target) ce = -1 * torch.log(preds+eps) * target print(ce) floss = torch.pow((1-preds), self.gamma) * ce print(floss) floss = torch.mul(floss, self.weight) print(floss) floss = torch.sum(floss, dim=1) print(floss) return torch.mean(floss) def one_hot(self, num, labels): one = torch.zeros((labels.size(0),num)) one[range(labels.size(0)),labels] = 1 return one
-
参数说明
-
初始化类时,需要传入 a 列表,类型为tensor,表示每个类别的样本占比的反比,比如5分类中,有某一类占比非常多,那么就设置为小于0.2,即相应的权重缩小,占比很小的类,相应的权重就要大于0.2
lf = Focal_Loss(torch.tensor([0.2,0.2,0.2,0.2,0.2]))
-
使用时,
logits
是神经网络的输出,不用计算softmax
,label是torchvision
类自动生成的标签loss = lf(logits,label)
-
-
例子,这里 logits 为(16*5)的tensor,表示批大小为16,5分类;label为每个样本的真实标签类别,对应 logits 的下标,是一个16维的tensor向量
logits = torch.tensor([[-2.7672, 3.6104, -7.4242, -3.2486, -3.1323], [-2.4270, 3.1833, -5.9394, -2.4592, -3.2292], [-2.5986, 3.3626, -6.7340, -2.8639, -3.1553], [-2.6206, 3.4201, -6.8754, -2.9308, -3.1507], [-2.8307, 3.7070, -7.6975, -3.3924, -3.1318], [-2.5776, 3.3316, -6.6595, -2.8187, -3.1542], [-2.8930, 3.7982, -7.9322, -3.5327, -3.1210], [-2.5489, 3.3580, -6.5229, -2.7590, -3.1912], [-1.5628, 1.8362, -1.8254, -0.3083, -3.5928], [ 0.2434, -4.9000, 1.1150, 2.7505, -1.0390], [-2.6877, 3.5686, -7.1178, -3.0847, -3.1617], [-2.6847, 3.5191, -6.8264, -3.0083, -3.2041], [-2.6137, 3.4025, -6.8965, -2.9250, -3.1396], [-2.7505, 3.5840, -7.3340, -3.2035, -3.1435], [-2.7030, 3.5163, -7.1549, -3.1002, -3.1424], [-2.6661, 3.4580, -7.0481, -3.0258, -3.1365]]) label = torch.tensor([1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 1, 1, 3, 1, 1, 1]) lf = Focal_Loss(torch.tensor([0.2,0.2,0.2,0.2,0.2])) loss = lf(logits,label) print('loss:', loss)
输出结果
loss: tensor(0.1902)
Pytorch实现多分类问题样本不均衡的权重损失函数 FocusLoss
于 2022-04-16 10:48:44 首次发布