# Focal loss
class MultiClassFocalLossWithAlpha(nn.Module):
    """Multi-class focal loss with per-class alpha weighting.

    Computes ``alpha_t * (1 - p_t)**gamma * CE`` per sample, where ``p_t`` is
    the softmax probability of the true class.

    :param alpha: list of per-class weights; the default encodes inverse
        class frequencies for a 5-class dataset of 28760 samples
    :param gamma: focusing parameter for hard-example mining
    :param reduction: 'mean', 'sum', or anything else for per-sample losses
    """

    def __init__(self, alpha=[1.-8434/28760, 1.-8069/28760, 1.-578/28760, 1.-3903/28760, 1.-7806/28760], gamma=2, reduction='mean'):
        super(MultiClassFocalLossWithAlpha, self).__init__()
        self.alpha = torch.tensor(alpha)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, pred, target):
        """Compute the focal loss for raw logits ``pred`` (bs, C) and integer
        class labels ``target`` (bs,)."""
        flat_target = target.view(-1)
        # Per-sample class weight, moved to the labels' device; shape (bs,).
        sample_alpha = self.alpha.to(target.device)[flat_target]
        # Log-probabilities of each class; shape (bs, C).
        log_probs = torch.log_softmax(pred, dim=1)
        # Log-probability at the true class for each sample; shape (bs,).
        target_log_prob = log_probs.gather(1, flat_target.unsqueeze(1)).squeeze(1)
        ce = -target_log_prob                # plain cross-entropy per sample
        prob = target_log_prob.exp()         # p_t, softmax score of true class
        loss = sample_alpha * (1 - prob) ** self.gamma * ce
        if self.reduction == "sum":
            return torch.sum(loss)
        if self.reduction == "mean":
            return torch.mean(loss)
        return loss
# Equalized focal loss (EFL)
class EqualizedFocalLossCW:
    """Equalized Focal Loss with class-wise alpha weighting.

    The focusing parameter gamma is made dynamic per batch: it grows when the
    accumulated negative gradient at the target logits outweighs the positive
    one (EFL-style gradient-guided reweighting).

    :param alpha: list of per-class weights, or None to disable class weighting;
        the default encodes inverse class frequencies for a 5-class dataset
    :param gamma_b: base focusing parameter
    :param scale_factor: scale of the dynamic gamma increment
    :param reduction: 'mean', 'sum', or 'none'
    :raises ValueError: for any other reduction value
    """

    def __init__(self, alpha=[1.-8434/28760, 1.-8069/28760, 1.-578/28760, 1.-3903/28760, 1.-7806/28760], gamma_b=2, scale_factor=8, reduction="mean"):
        self.gamma_b = gamma_b
        self.scale_factor = scale_factor
        self.reduction = reduction
        self.alpha = alpha

    def __call__(self, logits, targets):
        """Compute EFL for raw logits (bs, C) and integer labels (bs,).

        ``logits`` must have ``requires_grad=True`` — the dynamic gamma is
        derived from ``torch.autograd.grad`` with respect to them.
        """
        ce_loss = F.cross_entropy(logits, targets, reduction="none")
        # Scalar mean CE used only as the differentiation target below.
        outputs = F.cross_entropy(logits, targets)
        log_pt = -ce_loss
        pt = torch.exp(log_pt)  # softmax probability at the target class
        targets = targets.view(-1, 1)  # add a dim so gather can be used
        # BUG FIX: retain_graph=True keeps the autograd graph alive so a later
        # efl.backward() in the training loop does not fail with
        # "Trying to backward through the graph a second time".
        grad_i = torch.autograd.grad(outputs=-outputs, inputs=logits,
                                     retain_graph=True)[0]
        grad_i = grad_i.gather(1, targets)  # gradient at each target logit
        pos_grad_i = F.relu(grad_i).sum()
        neg_grad_i = F.relu(-grad_i).sum()
        neg_grad_i += 1e-9  # guard against division by zero
        grad_i = pos_grad_i / neg_grad_i
        grad_i = torch.clamp(grad_i, min=0, max=1)  # clip the gradient ratio
        dy_gamma = self.gamma_b + self.scale_factor * (1 - grad_i)
        dy_gamma = dy_gamma.view(-1)  # drop the extra dimension
        # weighting factor
        wf = dy_gamma / self.gamma_b
        weights = wf * (1 - pt) ** dy_gamma
        if self.alpha is not None:
            # Apply class-specific weights
            alpha_tensor = torch.tensor(self.alpha, device=logits.device)
            alpha_weights = alpha_tensor[targets.view(-1)]
            weights = weights * alpha_weights
        efl = weights * ce_loss
        if self.reduction == "sum":
            return efl.sum()
        if self.reduction == "mean":
            return efl.mean()
        # BUG FIX: 'none' previously fell through to the ValueError; return
        # the per-sample losses, consistent with MultiClassFocalLossWithAlpha.
        if self.reduction == "none":
            return efl
        raise ValueError(f"reduction '{self.reduction}' is not valid")