The code is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def _squeeze_binary_labels(label):
    # Convert a one-hot / multi-hot label matrix into a vector of class indices.
    if label.size(1) == 1:
        squeeze_label = label.view(len(label), -1)
    else:
        inds = torch.nonzero(label >= 1).squeeze()
        squeeze_label = inds[:, -1]
    return squeeze_label
def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
    # element-wise losses
    if label.size(-1) != pred.size(0):
        label = _squeeze_binary_labels(label)

    loss = F.cross_entropy(pred, label, reduction='none')

    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss
def _expand_binary_labels(labels, label_weights, label_channels):
    # Expand index labels (1-based) into a multi-hot matrix of shape (N, label_channels).
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(labels >= 1).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds] - 1] = 1
    if label_weights is None:
        bin_label_weights = None
    else:
        bin_label_weights = label_weights.view(-1, 1).expand(
            label_weights.size(0), label_channels)
    return bin_labels, bin_label_weights
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None):
    if pred.dim() != label.dim():
        label, weight = _expand_binary_labels(label, weight, pred.size(-1))

    # weighted element-wise losses
    if weight is not None:
        weight = weight.float()

    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), weight, reduction='none')
    loss = weight_reduce_loss(loss, reduction=reduction, avg_factor=avg_factor)

    return loss
def partial_cross_entropy(pred,
                          label,
                          weight=None,
                          reduction='mean',
                          avg_factor=None):
    if pred.dim() != label.dim():
        label, weight = _expand_binary_labels(label, weight, pred.size(-1))

    # weighted element-wise losses
    if weight is not None:
        weight = weight.float()

    # entries labeled -1 are unannotated and are masked out of the loss
    mask = label == -1
    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), weight, reduction='none')
    if mask.sum() > 0:
        loss = loss * (~mask).float()
        avg_factor = (~mask).float().sum()

    # do the reduction for the weighted loss
    loss = weight_reduce_loss(loss, reduction=reduction, avg_factor=avg_factor)
    return loss
def kpos_cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
    # element-wise losses
    if pred.dim() != label.dim():
        label, weight = _expand_binary_labels(label, weight, pred.size(-1))

    # normalize the multi-hot target so each sample's positives sum to 1
    target = label.float() / torch.sum(label, dim=1, keepdim=True).float()
    loss = - target * F.log_softmax(pred, dim=1)

    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss
class ResampleLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True, partial=False,
                 loss_weight=1.0, reduction='mean',
                 reweight_func=None,  # None, 'inv', 'sqrt_inv', 'rebalance', 'CB'
                 weight_norm=None,  # None, 'by_instance', 'by_batch'
                 focal=dict(focal=True, alpha=0.5, gamma=2),
                 map_param=dict(alpha=10.0, beta=0.2, gamma=0.1),
                 CB_loss=dict(CB_beta=0.9, CB_mode='average_w'),  # 'by_class', 'average_n', 'average_w', 'min_n'
                 logit_reg=dict(neg_scale=5.0, init_bias=0.1),
                 class_freq=None,
                 train_num=None, device=None):
        super(ResampleLoss, self).__init__()

        self.device = device

        assert (use_sigmoid is True) or (partial is False)
        self.use_sigmoid = use_sigmoid
        self.partial = partial
        self.loss_weight = loss_weight
        self.reduction = reduction
        if self.use_sigmoid:
            if self.partial:
                self.cls_criterion = partial_cross_entropy
            else:
                self.cls_criterion = binary_cross_entropy
        else:
            self.cls_criterion = cross_entropy

        # reweighting function
        self.reweight_func = reweight_func

        # normalization (optional)
        self.weight_norm = weight_norm

        # focal loss params
        self.focal = focal['focal']
        self.gamma = focal['gamma']
        self.alpha = focal['alpha']

        # mapping function params
        self.map_alpha = map_param['alpha']
        self.map_beta = map_param['beta']
        self.map_gamma = map_param['gamma']

        # CB loss params (optional)
        self.CB_beta = CB_loss['CB_beta']
        self.CB_mode = CB_loss['CB_mode']

        self.class_freq = torch.from_numpy(np.asarray(class_freq)).float().to(self.device)
        self.num_classes = self.class_freq.shape[0]
        self.train_num = train_num  # only used to be divided by class_freq

        # regularization params
        self.logit_reg = logit_reg
        self.neg_scale = logit_reg['neg_scale'] if 'neg_scale' in logit_reg else 1.0
        init_bias = logit_reg['init_bias'] if 'init_bias' in logit_reg else 0.0
        # bug fixed, see https://github.com/wutong16/DistributionBalancedLoss/issues/8
        self.init_bias = - torch.log(
            self.train_num / self.class_freq - 1) * init_bias

        self.freq_inv = torch.ones(self.class_freq.shape).to(self.device) / self.class_freq
        self.propotion_inv = self.train_num / self.class_freq
    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):

        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        # note: the weight argument is overridden by the reweighting function
        weight = self.reweight_functions(label)

        cls_score, weight = self.logit_reg_functions(label.float(), cls_score, weight)

        if self.focal:
            logpt = self.cls_criterion(
                cls_score.clone(), label, weight=None, reduction='none',
                avg_factor=avg_factor)
            # pt is sigmoid(logit) for pos or sigmoid(-logit) for neg
            pt = torch.exp(-logpt)
            wtloss = self.cls_criterion(
                cls_score, label.float(), weight=weight, reduction='none')
            # alpha_t balances positives vs. negatives: alpha at positive positions,
            # 1 - alpha at negative positions
            alpha_t = torch.where(label == 1,
                                  torch.tensor(self.alpha, device=label.device),
                                  torch.tensor(1 - self.alpha, device=label.device))
            loss = alpha_t * ((1 - pt) ** self.gamma) * wtloss  # focal modulation of the weighted loss
            loss = reduce_loss(loss, reduction)
        else:
            loss = self.cls_criterion(cls_score, label.float(), weight,
                                      reduction=reduction)

        loss = self.loss_weight * loss
        return loss
    def reweight_functions(self, label):
        if self.reweight_func is None:
            return None
        elif self.reweight_func in ['inv', 'sqrt_inv']:
            weight = self.RW_weight(label.float())
        elif self.reweight_func == 'rebalance':
            weight = self.rebalance_weight(label.float())
        elif self.reweight_func == 'CB':
            weight = self.CB_weight(label.float())
        else:
            return None

        if self.weight_norm is not None:
            if 'by_instance' in self.weight_norm:
                max_by_instance, _ = torch.max(weight, dim=-1, keepdim=True)
                weight = weight / max_by_instance
            elif 'by_batch' in self.weight_norm:
                weight = weight / torch.max(weight)

        return weight
    def logit_reg_functions(self, labels, logits, weight=None):
        # logit regularization, i.e. re-weighting the negative samples
        if not self.logit_reg:
            return logits, weight
        if 'init_bias' in self.logit_reg:
            logits += self.init_bias
        if 'neg_scale' in self.logit_reg:
            # scale the logits at negative positions by neg_scale; positives unchanged
            logits = logits * (1 - labels) * self.neg_scale + logits * labels
            if weight is not None:
                # divide the weights at negative positions by neg_scale; positives unchanged
                weight = weight / self.neg_scale * (1 - labels) + weight * labels
        return logits, weight
    def rebalance_weight(self, gt_labels):
        # repeat_rate: sum of freq_inv over each sample's positive classes
        repeat_rate = torch.sum(gt_labels.float() * self.freq_inv, dim=1, keepdim=True)
        # normalize the positive weights into [0, 1]; negative weights shrink correspondingly
        pos_weight = self.freq_inv.clone().detach().unsqueeze(0) / repeat_rate
        # pos and neg are equally treated
        weight = torch.sigmoid(self.map_beta * (pos_weight - self.map_gamma)) + self.map_alpha
        return weight
    def CB_weight(self, gt_labels):
        if 'by_class' in self.CB_mode:
            weight = torch.tensor((1 - self.CB_beta)).to(self.device) / \
                     (1 - torch.pow(self.CB_beta, self.class_freq)).to(self.device)
        elif 'average_n' in self.CB_mode:
            avg_n = torch.sum(gt_labels * self.class_freq, dim=1, keepdim=True) / \
                    torch.sum(gt_labels, dim=1, keepdim=True)
            weight = torch.tensor((1 - self.CB_beta)).to(self.device) / \
                     (1 - torch.pow(self.CB_beta, avg_n)).to(self.device)
        elif 'average_w' in self.CB_mode:
            weight_ = torch.tensor((1 - self.CB_beta)).to(self.device) / \
                      (1 - torch.pow(self.CB_beta, self.class_freq)).to(self.device)
            weight = torch.sum(gt_labels * weight_, dim=1, keepdim=True) / \
                     torch.sum(gt_labels, dim=1, keepdim=True)
        elif 'min_n' in self.CB_mode:
            min_n, _ = torch.min(gt_labels * self.class_freq +
                                 (1 - gt_labels) * 100000, dim=1, keepdim=True)
            weight = torch.tensor((1 - self.CB_beta)).to(self.device) / \
                     (1 - torch.pow(self.CB_beta, min_n)).to(self.device)
        else:
            raise NameError('Unknown CB_mode: {}'.format(self.CB_mode))
        return weight
    def RW_weight(self, gt_labels, by_class=True):
        if 'sqrt' in self.reweight_func:
            weight = torch.sqrt(self.propotion_inv)
        else:
            weight = self.propotion_inv
        if not by_class:
            sum_ = torch.sum(weight * gt_labels, dim=1, keepdim=True)
            weight = sum_ / torch.sum(gt_labels, dim=1, keepdim=True)
        return weight
def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Return:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights.
        reduction (str): Same as built-in losses of PyTorch.
        avg_factor (float): Average factor when computing the mean of losses.

    Returns:
        Tensor: Processed loss values.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        loss = reduce_loss(loss, reduction)
    else:
        # if reduction is mean, then average the loss by avg_factor
        if reduction == 'mean':
            loss = loss.sum() / avg_factor
        # if reduction is 'none', then do nothing, otherwise raise an error
        elif reduction != 'none':
            raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss
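For completeness, here is a minimal usage sketch. It assumes only that the code above is in scope; the class frequencies, train_num, and label values are made-up numbers for illustration, not values from the original paper or repo.

import torch

# hypothetical statistics: 4 classes, 100 training samples
class_freq = [50, 30, 15, 5]
loss_fn = ResampleLoss(use_sigmoid=True,
                       reweight_func='rebalance',
                       logit_reg=dict(neg_scale=5.0, init_bias=0.1),
                       class_freq=class_freq,
                       train_num=100,
                       device=torch.device('cpu'))

logits = torch.randn(2, 4)                # raw scores for a batch of 2 samples
labels = torch.tensor([[1., 0., 0., 1.],  # multi-hot ground truth
                       [0., 1., 1., 0.]])
loss = loss_fn(logits, labels)
print(loss)  # scalar (reduction='mean' by default)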
Overall idea:
1. rebalance_weight normalizes the weights of all positive samples in the current batch into [0, 1]; the weights at negative positions grow correspondingly (they are divided by a number smaller than 1).
2. logit_reg_functions regularizes the logits: values at negative positions are scaled up by neg_scale while positive positions keep their values; correspondingly, the weights at negative positions are divided by neg_scale while positive weights stay unchanged.
3. forward adds the self.focal branch: logpt is the element-wise loss computed without weights, so pt = torch.exp(-logpt) recovers the predicted probability of the ground truth; wtloss is the loss computed with the re-balancing weights; the final loss is the reduction (mean) of alpha_t * ((1 - pt) ** self.gamma) * wtloss.
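To make points 1 and 2 concrete, a small hand-check using the loss_fn constructed in the sketch above (illustrative numbers only):

labels = torch.tensor([[1., 0., 0., 1.]])

# Point 1: freq_inv = 1 / class_freq = [0.02, 0.0333, 0.0667, 0.2];
# repeat_rate = 0.02 + 0.2 = 0.22 for this sample, so pos_weight = freq_inv / 0.22,
# then smoothed via sigmoid(map_beta * (pos_weight - map_gamma)) + map_alpha.
w = loss_fn.rebalance_weight(labels)

# Point 2: init_bias is added to every logit, then logits at negative positions are
# multiplied by neg_scale (5.0 here) and the weights at negative positions are divided by 5.
logits = torch.zeros(1, 4)
new_logits, new_w = loss_fn.logit_reg_functions(labels, logits.clone(), w)
print(new_logits, new_w)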