"""Class-balanced (CB) loss implemented in PyTorch; the functions can be called directly.

Reference: https://github.com/vandit15/Class-balanced-loss-pytorch/blob/master/class_balanced_loss.py
"""
import numpy as np
import torch
import torch.nn.functional as F
def focal_loss(logits, labels, alpha, gamma):
"""Compute the focal loss between `logits` and the ground truth `labels`.
Focal loss = -alpha_t * (1-pt)^gamma * log(pt)
where pt is the probability of being classified to the true class.
pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit).
Args:
logits: A float tensor of size [batch, num_classes].
labels: A float tensor of size [batch, num_classes].
alpha: A float tensor of size [batch_size]
specifying per-example weight for balanced cross entropy.
gamma: A float scalar modulating loss from hard and easy examples.
Returns:
focal_loss: A float32 scalar representing normalized total loss.
"""
bce_loss = F.binary_cross_entropy_with_logits(input=logits, target=labels, reduction="none")
if gamma == 0.0:
modulator = 1.0
else:
modulator = torch.exp(-gamma * labels * logits - gamma * torch.log(1 + torch.exp(-1.0 * logits)))
loss = modulator