velikodniy: I've implemented an analog of `weighted_cross_entropy_with_logits` in my current project. It's useful for working with imbalanced datasets. I'd like to add it to PyTorch, but I'm not sure whether others would actually need it.
Here is my implementation:
def weighted_binary_cross_entropy_with_logits(logits, targets, pos_weight, weight=None, size_average=True, reduce=True):
    """Binary cross-entropy on logits with an extra weight on the positive term.

    Computes, element-wise,
        pos_weight * targets * -log(sigmoid(logits))
        + (1 - targets) * -log(1 - sigmoid(logits))
    in a numerically stable form that never evaluates sigmoid directly.

    Args:
        logits: raw (pre-sigmoid) scores.
        targets: tensor of the same size as ``logits`` (presumably in [0, 1]).
        pos_weight: multiplier for the positive-target term only; a scalar or
            a tensor broadcastable against ``targets``.
        weight: optional element-wise rescaling applied to the whole loss;
            must broadcast against ``logits``.
        size_average: when reducing, average (True) or sum (False).
        reduce: if False, return the unreduced element-wise loss.

    Returns:
        The element-wise loss if ``reduce`` is False, otherwise its mean or sum.

    Raises:
        ValueError: if ``targets`` and ``logits`` differ in size.
    """
    if targets.size() != logits.size():
        raise ValueError("Target size ({}) must be the same as input size ({})".format(targets.size(), logits.size()))

    # Stable softplus(-x) = log(1 + exp(-x)) = -log(sigmoid(x)): shifting by
    # max_val = max(-x, 0) keeps both exponents <= 0, so exp() cannot overflow.
    max_val = (-logits).clamp(min=0)
    softplus_neg = ((-max_val).exp() + (-logits - max_val).exp()).log() + max_val

    # Since -log(1 - sigmoid(x)) = x + softplus(-x), the weighted loss becomes
    # (1 - t) * x + (pos_weight * t + (1 - t)) * softplus(-x).
    log_weight = 1 + (pos_weight - 1) * targets
    loss = (1 - targets) * logits + log_weight * softplus_neg

    if weight is not None:
        loss = loss * weight

    if not reduce:
        return loss
    if size_average:
        return loss.mean()
    return loss.sum()
class WeightedBCEWithLogitsLoss(torch.nn.Module):
    """Module wrapper around ``weighted_binary_cross_entropy_with_logits``.

    ``pos_weight`` and ``weight`` are registered as buffers, so they move with
    the module on ``.to(device)`` and are saved in ``state_dict`` (when not None).

    Args:
        pos_weight: multiplier for the positive-target term; a tensor or
            anything ``torch.as_tensor`` accepts (e.g. a float or a list).
        weight: optional element-wise rescaling of the loss (tensor-like or None).
        size_average: when reducing, average (True) or sum (False).
        reduce: if False, return the unreduced element-wise loss.
    """

    def __init__(self, pos_weight, weight=None, size_average=True, reduce=True):
        super().__init__()
        # register_buffer accepts only tensors (or None); promote plain Python
        # numbers/sequences instead of failing with a TypeError.
        if pos_weight is not None and not isinstance(pos_weight, torch.Tensor):
            pos_weight = torch.as_tensor(pos_weight, dtype=torch.get_default_dtype())
        if weight is not None and not isinstance(weight, torch.Tensor):
            weight = torch.as_tensor(weight, dtype=torch.get_default_dtype())
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)
        self.size_average = size_average
        self.reduce = reduce

    def forward(self, input, target):
        # Buffers are already tensors that autograd can consume, so the legacy
        # Variable wrapping is unnecessary. self.weight is either a tensor or
        # None, so one call covers both cases.
        return weighted_binary_cross_entropy_with_logits(
            input,
            target,
            self.pos_weight,
            weight=self.weight,
            size_average=self.size_average,
            reduce=self.reduce,
        )
Note that `pos_weight` multiplies only the first term (the positive-target term) of the BCE loss; it is not a weight on the whole target. I don't see a simple way to achieve the same effect by transforming the targets beforehand.
Let $x$ be the logits, $y$ the targets, $p$ the `pos_weight`, and $\sigma$ the sigmoid. The proposed loss is

$$\ell = -p\,y\,\log\sigma(x) - (1 - y)\,\log\bigl(1 - \sigma(x)\bigr).$$

`BCEWithLogitsLoss` with the targets multiplied by a factor $m$ gives

$$\ell_m = -m\,y\,\log\sigma(x) - (1 - m\,y)\,\log\bigl(1 - \sigma(x)\bigr).$$

For these two expressions to be equal (for $y \neq 0$), $m$ must be

$$m = \frac{\log\bigl(1 - \sigma(x)\bigr) - p\,\log\sigma(x)}{\log\bigl(1 - \sigma(x)\bigr) - \log\sigma(x)}.$$

This formula is quite complex and depends on the logits, and I suspect it is numerically unstable.
Maybe I didn't understand your idea.
Following @velikodniy, I added a weighted BCE loss in which the weights can be computed dynamically for each batch:
def weighted_binary_cross_entropy(sigmoid_x, targets, pos_weight, weight=None, size_average=True, reduce=True):
    """Binary cross-entropy on probabilities with an extra weight on the positive term.

    Computes, element-wise,
        -pos_weight * targets * log(sigmoid_x) - (1 - targets) * log(1 - sigmoid_x)

    Args:
        sigmoid_x: predicted probabilities of size [N, C] (N samples, C classes),
            e.g. the output of a sigmoid; values must lie in [0, 1].
        targets: ground truth of the same size, a one-hot-like (multi-label) tensor.
        pos_weight: weight for the positive-target term; a scalar or a tensor
            broadcastable against ``targets``.
        weight: optional element-wise rescaling of the loss.
        size_average: when reducing, average (True) or sum (False).
        reduce: if False, return the unreduced element-wise loss.

    Returns:
        The element-wise loss if ``reduce`` is False, otherwise its mean or sum.

    Raises:
        ValueError: if ``targets`` and ``sigmoid_x`` differ in size.
    """
    if targets.size() != sigmoid_x.size():
        raise ValueError("Target size ({}) must be the same as input size ({})".format(targets.size(), sigmoid_x.size()))

    # Clamp the logs at -100, as torch.nn.functional.binary_cross_entropy does.
    # A saturated probability of exactly 0 or 1 would otherwise give
    # 0 * -inf = NaN, even when the prediction is correct.
    # NOTE: autograd can still produce non-finite gradients at exactly 0/1;
    # prefer the with-logits variant for training when possible.
    log_p = sigmoid_x.log().clamp(min=-100)
    log_not_p = (1 - sigmoid_x).log().clamp(min=-100)
    loss = -pos_weight * targets * log_p - (1 - targets) * log_not_p

    if weight is not None:
        loss = loss * weight

    if not reduce:
        return loss
    if size_average:
        return loss.mean()
    return loss.sum()
class WeightedBCELoss(Module):
    """Module wrapper around ``weighted_binary_cross_entropy`` (probability inputs).

    Args:
        pos_weight: weight for positive samples, a scalar or size [1, C];
            a tensor or anything ``torch.as_tensor`` accepts (default 1).
        weight: weight for each class, size [1, C], or None.
        PosWeightIsDynamic: if True, ``forward`` ignores the stored
            ``pos_weight`` and derives one from each batch as the per-class
            ratio (#negatives / #positives) over the batch dimension.
        WeightIsDynamic: stored for API compatibility, but currently has no
            effect; ``weight`` is always the static buffer.
        size_average: when reducing, average (True) or sum (False).
        reduce: if False, return the unreduced element-wise loss.
    """

    def __init__(self, pos_weight=1, weight=None, PosWeightIsDynamic=False, WeightIsDynamic=False, size_average=True, reduce=True):
        super().__init__()
        # register_buffer accepts only tensors (or None); without this
        # promotion the default pos_weight=1 raised TypeError.
        if pos_weight is not None and not isinstance(pos_weight, torch.Tensor):
            pos_weight = torch.as_tensor(pos_weight, dtype=torch.get_default_dtype())
        if weight is not None and not isinstance(weight, torch.Tensor):
            weight = torch.as_tensor(weight, dtype=torch.get_default_dtype())
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)
        self.size_average = size_average
        self.reduce = reduce
        self.PosWeightIsDynamic = PosWeightIsDynamic
        # NOTE(review): no dynamic weight computation is implemented yet.
        self.WeightIsDynamic = WeightIsDynamic

    def forward(self, input, target):
        pos_weight = self.pos_weight
        if self.PosWeightIsDynamic:
            # Per-class negative/positive ratio over the batch dimension.
            # The 1e-5 avoids division by zero for classes with no positives.
            # The result is kept local so forward() does not overwrite the
            # configured buffer (and hence the saved state_dict).
            positive_counts = target.sum(dim=0)
            n_batch = target.size(0)
            pos_weight = (n_batch - positive_counts) / (positive_counts + 1e-5)
        return weighted_binary_cross_entropy(
            input,
            target,
            pos_weight,
            weight=self.weight,
            size_average=self.size_average,
            reduce=self.reduce,
        )