import torch
import torch.nn.functional as F


def balanced_softmax_loss(labels, logits, sample_per_class, reduction):
    """Compute the Balanced Softmax Loss between `logits` and the ground-truth `labels`.
    Args:
      labels: An int tensor of size [batch].
      logits: A float tensor of size [batch, no_of_classes].
      sample_per_class: An int tensor of size [no_of_classes].
      reduction: string. One of "none", "mean", "sum".
    Returns:
      loss: A float tensor. Balanced Softmax Loss.
    """
    # Cast the per-class sample counts to the dtype/device of the logits.
    spc = sample_per_class.type_as(logits)
    # Broadcast to [batch, no_of_classes] so every row carries the class priors.
    spc = spc.unsqueeze(0).expand(logits.shape[0], -1)
    # Shift each logit by log(n_c): head classes get a larger additive offset,
    # so tail classes need larger raw logits to reach the same probability.
    logits = logits + spc.log()
    loss = F.cross_entropy(input=logits, target=labels, reduction=reduction)
    return loss
This makes the model pay more attention to the minority classes: a minority class has a smaller log(sample count), so its raw logit must be larger to achieve the same softmax probability. That is roughly the intuition, I think. The paper also mentions a Meta Sampler, which I don't understand yet.
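To see the effect concretely, here is a small self-contained sketch (the class counts and batch are made up for illustration) that applies the same log-prior shift as the function above. With uniform logits, plain cross-entropy gives every example the same loss, but after adding log(n_c) the tail-class examples incur a much larger loss, so their gradients push the tail logits up:

```python
import torch
import torch.nn.functional as F

# Hypothetical long-tailed setup: 3 classes with counts 1000 / 100 / 10.
sample_per_class = torch.tensor([1000, 100, 10])
logits = torch.zeros(4, 3)            # uniform logits for illustration
labels = torch.tensor([0, 1, 2, 2])   # two examples from the tail class

# Same shift as in balanced_softmax_loss: add log(n_c) to each class logit.
spc = sample_per_class.type_as(logits)
adjusted = logits + spc.unsqueeze(0).expand(logits.shape[0], -1).log()

balanced = F.cross_entropy(adjusted, labels, reduction="mean")
plain = F.cross_entropy(logits, labels, reduction="mean")

# Per-example balanced loss is log(sum_c n_c) - log(n_y), so the tail
# examples (n_y = 10) dominate and the mean exceeds plain cross-entropy.
print(balanced.item(), plain.item())
```

The closed form is easy to check by hand: with uniform logits, the balanced loss for an example with label `y` is `log(1110) - log(n_y)`, which is about 0.10 for the head class but about 4.71 for the tail class.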