Solving the Class Imbalance Problem
Resampling
label_id_level_1 = int(example.label_level_1)
label_id_level_2 = int(example.label_level_2)
# Inverse class frequency, square-rooted to soften the correction
samp_weight = math.sqrt(1 / label2freq_level_2[label_list_level_2[label_id_level_2]])
sample_weights.append(samp_weight)
In PyTorch, WeightedRandomSampler is used to set each sample's probability of being drawn, raising the chance that minority-class samples appear during training.
def train(self):
    if self.args.use_weighted_sampler:
        train_sampler = WeightedRandomSampler(
            self.train_sample_weights,
            len(self.train_sample_weights),
        )
    else:
        train_sampler = RandomSampler(self.train_dataset)
    train_dataloader = DataLoader(
        self.train_dataset,
        sampler=train_sampler,
        batch_size=self.args.train_batch_size,
    )
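As a quick check that the sampler behaves as intended, here is a standalone toy sketch (not from the training code above; the dataset and weights are made up) that counts how often each class is drawn under the same inverse-frequency weighting recipe:

import math
from collections import Counter
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical imbalanced dataset: 90 samples of class 0, 10 of class 1.
labels = torch.tensor([0] * 90 + [1] * 10)
dataset = TensorDataset(torch.randn(100, 4), labels)

label2freq = Counter(labels.tolist())
# Same recipe as above: square root of the inverse class frequency.
weights = [math.sqrt(1 / label2freq[int(y)]) for y in labels]

sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=10)

seen = Counter()
for _, batch_labels in loader:
    seen.update(batch_labels.tolist())
print(seen)  # class 1 appears far more often than its raw 10% share

With the square-root weighting the minority class is boosted to roughly a quarter of the draws rather than fully balanced; dropping the square root would equalize the two classes.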
Reweighting
# class weights: inverse frequency per level-1 label
class_weights_level_1 = []
for i, lab in enumerate(label_list_level_1):
    class_weights_level_1.append(label2freq_level_1[lab])
class_weights_level_1 = [1 / w for w in class_weights_level_1]
if self.args.use_weighted_sampler:
    # resampling already corrects part of the imbalance, so soften the weights
    class_weights_level_1 = [math.sqrt(w) for w in class_weights_level_1]
print("class_weights_level_1: ", class_weights_level_1)
# normalize the weights so they sum to 1
self.class_weights_level_1 = F.softmax(torch.FloatTensor(
    class_weights_level_1
).to(self.args.device), dim=-1)
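To make the recipe concrete, a minimal standalone sketch (the frequency dict is made up) showing how raw inverse frequencies turn into normalized class weights:

import torch
import torch.nn.functional as F

# Hypothetical frequencies: label "A" dominates the data.
label2freq_level_1 = {"A": 900, "B": 90, "C": 10}
label_list_level_1 = ["A", "B", "C"]

weights = [1 / label2freq_level_1[lab] for lab in label_list_level_1]
weights = F.softmax(torch.FloatTensor(weights), dim=-1)
print(weights)  # the rare label "C" gets the largest weight

Note that softmax over small raw values keeps the weights close to uniform: it normalizes them but compresses their spread, which softens the correction compared to using raw inverse frequencies directly.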
# 1. loss
if label_ids_level_2 is not None:
if self.args.use_focal_loss:
loss_fct = FocalLoss(
self.num_labels_level_2,
alpha=self.class_weights_level_2,
gamma=self.args.focal_loss_gamma,
size_average=True
)
elif self.args.use_class_weights:
loss_fct = nn.CrossEntropyLoss(weight=self.class_weights_level_2)
else:
loss_fct = nn.CrossEntropyLoss()
loss_level_2 = loss_fct(
logits_level_2.view(-1, self.num_labels_level_2),
label_ids_level_2.view(-1)
)
outputs = (loss_level_2,) + outputs
return outputs
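For intuition about the class-weighted branch, a tiny standalone comparison (toy logits and made-up weights) of unweighted versus weighted cross-entropy:

import torch
import torch.nn as nn

logits = torch.tensor([[2.0, 0.5], [0.3, 0.1]])
labels = torch.tensor([0, 1])

plain = nn.CrossEntropyLoss()(logits, labels)
# up-weight class 1, the assumed minority class
weighted = nn.CrossEntropyLoss(weight=torch.tensor([0.2, 0.8]))(logits, labels)
print(plain, weighted)  # the weighted loss is dominated by the class-1 sample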
Focal Loss
class FocalLoss(nn.Module):
r"""
This criterion is a implemenation of Focal Loss, which is proposed in
Focal Loss for Dense Object Detection.
Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
The losses are averaged across observations for each minibatch.
Args:
        alpha(1D Tensor): per-class weighting factor for this criterion
gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
putting more focus on hard, misclassified examples
size_average(bool): By default, the losses are averaged over observations for each minibatch.
However, if the field size_average is set to False, the losses are
instead summed for each minibatch.
"""
def __init__(self, class_num, alpha=None, gamma=2, size_average=True, device=None):
super(FocalLoss, self).__init__()
if alpha is None:
self.alpha = torch.ones(class_num, 1).to(device)
else:
self.alpha = alpha.to(device)
self.gamma = gamma
self.class_num = class_num
self.size_average = size_average
    def forward(self, inputs, targets):
        P = F.softmax(inputs, dim=1)
        # one-hot mask of the target class for each sample: [N, C]
        class_mask = torch.zeros_like(inputs).to(inputs.device)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)
        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        # per-sample weight alpha[target[i]]: [N, 1]
        alpha = self.alpha[ids.data.view(-1)]
        # probability assigned to the gold class: [N, 1]
        probs = (P * class_mask).sum(1).view(-1, 1)
        log_p = probs.log()
        # the (1 - p)^gamma factor down-weights well-classified examples
        batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p
        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss
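A quick self-contained sanity check for the class above (toy tensors): with gamma=0 and the default uniform alpha, the focal loss reduces to plain cross-entropy.

import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 5)
targets = torch.tensor([1, 0, 3, 2])

focal = FocalLoss(class_num=5, gamma=0, device="cpu")
ce = nn.CrossEntropyLoss()
print(focal(logits, targets))  # matches the value below
print(ce(logits, targets))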
from typing import List, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
"""
    Focal Loss (https://arxiv.org/pdf/1708.02002.pdf)
Shape:
- input: (N, C)
- target: (N)
- Output: Scalar loss
Examples:
>>> loss = FocalLoss(gamma=2, alpha=[1.0]*7)
>>> input = torch.randn(3, 7, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(7)
>>> output = loss(input, target)
>>> output.backward()
"""
def __init__(self, gamma=0, alpha=None, reduction="none"):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
if alpha is not None:
if isinstance(alpha, list):
self.alpha = torch.FloatTensor(alpha)
else:
self.alpha = alpha
self.reduction = reduction
    def forward(self, input, target):
        """
        Args:
            input: (N, C), logits
            target: (N,)
        Returns:
            Scalar loss, or per-sample losses of shape (N,) if reduction == "none"
        """
        # [N, 1]
        target = target.unsqueeze(-1)
        # [N, C]
        pt = F.softmax(input, dim=-1)
        logpt = F.log_softmax(input, dim=-1)
        # probability and log-probability assigned to the gold label
        # [N]
        pt = pt.gather(1, target).squeeze(-1)
        logpt = logpt.gather(1, target).squeeze(-1)
        # apply class weights
        if self.alpha is not None:
            # [N] at[i] = alpha[target[i]]
            # per-sample class weight looked up from the target label
            at = self.alpha.gather(0, target.squeeze(-1))
            logpt = logpt * at.to(logpt.device)
        loss = -1 * (1 - pt) ** self.gamma * logpt
if self.reduction == "none":
return loss
if self.reduction == "mean":
return loss.mean()
return loss.sum()
@staticmethod
def convert_binary_pred_to_two_dimension(x, is_logits=True):
"""
Args:
x: (*): (log) prob of some instance has label 1
is_logits: if True, x represents log prob; otherwhise presents prob
Returns:
y: (*, 2), where y[*, 1] == log prob of some instance has label 0,
y[*, 0] = log prob of some instance has label 1
"""
probs = torch.sigmoid(x) if is_logits else x
probs = probs.unsqueeze(-1)
probs = torch.cat([1-probs, probs], dim=-1)
logprob = torch.log(probs+1e-4) # 1e-4 to prevent being rounded to 0 in fp16
return logprob
def __str__(self):
return f"Focal Loss gamma:{self.gamma}"
def __repr__(self):
return str(self)
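A usage sketch for the binary helper above (toy tensor): a single-logit binary head is expanded to two log-prob columns so the same (N, C) focal loss can score it. Since the two columns exponentiate back to probabilities summing to roughly one, the internal log_softmax is close to the identity here.

import torch

binary_logits = torch.tensor([2.0, -1.0, 0.5])  # one logit per instance
targets = torch.tensor([1, 0, 1])

log_probs = FocalLoss.convert_binary_pred_to_two_dimension(binary_logits)
print(log_probs.shape)  # torch.Size([3, 2])

loss = FocalLoss(gamma=2, reduction="mean")(log_probs, targets)
print(loss)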
Dice Loss
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional
class DiceLoss(nn.Module):
"""
    The Dice coefficient is an F1-oriented statistic
    used to gauge the similarity of two sets.
Given two sets A and B, the vanilla dice coefficient
between them is given as follows:
Dice(A, B) = 2 * True_Positive / (2 * True_Positive + False_Positive + False_Negative)
= 2 * |A and B| / (|A| + |B|)
    Math Function:
        U-NET: https://arxiv.org/abs/1505.04597.pdf
        dice_loss(p, y) = 1 - numerator / denominator
            numerator = 2 * \sum_{i=1}^{t} p_i * y_i + smooth
            denominator = \sum_{i=1}^{t} p_i + \sum_{i=1}^{t} y_i + smooth
        if square_denominator is True, the denominator is
            \sum_{i=1}^{t} (p_i ** 2) + \sum_{i=1}^{t} (y_i ** 2) + smooth
        V-NET: https://arxiv.org/abs/1606.04797.pdf
    Args:
        smooth (float, optional): a manual smooth value for numerator and denominator.
        square_denominator (bool, optional): [True, False], specifies whether to square the denominator in the loss function.
        with_logits (bool, optional): [True, False], specifies whether the
            input tensor is normalized by Sigmoid/Softmax funcs.
        ohem_ratio: OHEM (online hard example mining),
            max ratio of positives to negatives; defaults to 0.0, which disables OHEM.
        alpha: the alpha of self-adjusting Dice (DSC), which down-weights easy examples
Shape:
- input: (*)
- target: (*)
- mask: (*) 0,1 mask for the input sequence.
- Output: Scalar loss
Examples:
>>> loss = DiceLoss(with_logits=True, ohem_ratio=0.1)
>>> input = torch.FloatTensor([2, 1, 2, 2, 1])
>>> input.requires_grad=True
>>> target = torch.LongTensor([0, 1, 0, 0, 0])
>>> output = loss(input, target)
>>> output.backward()
"""
def __init__(self,
smooth: Optional[float] = 1e-4,
square_denominator: Optional[bool] = False,
with_logits: Optional[bool] = True,
ohem_ratio: float = 0.0,
alpha: float = 0.0,
reduction: Optional[str] = "mean",
index_label_position=True) -> None:
super(DiceLoss, self).__init__()
self.reduction = reduction
self.with_logits = with_logits
self.smooth = smooth
self.square_denominator = square_denominator
self.ohem_ratio = ohem_ratio
self.alpha = alpha
self.index_label_position = index_label_position
def forward(self, input: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> Tensor:
logits_size = input.shape[-1]
if logits_size != 1:
loss = self._multiple_class(input, target, logits_size, mask=mask)
else:
loss = self._binary_class(input, target, mask=mask)
if self.reduction == "mean":
return loss.mean()
if self.reduction == "sum":
return loss.sum()
return loss
    def _compute_dice_loss(self, flat_input, flat_target):
        # the alpha term down-weights easy examples, as in focal loss
        flat_input = ((1 - flat_input) ** self.alpha) * flat_input
        intersection = torch.sum(flat_input * flat_target, -1)
        if not self.square_denominator:
            loss = 1 - ((2 * intersection + self.smooth) /
                        (flat_input.sum() + flat_target.sum() + self.smooth))
        else:
            loss = 1 - ((2 * intersection + self.smooth) /
                        (torch.sum(torch.square(flat_input), -1) + torch.sum(torch.square(flat_target), -1) + self.smooth))
        return loss
def _multiple_class(self, input, target, logits_size, mask=None):
# input: [N, C]
flat_input = input
flat_input = torch.nn.Softmax(dim=1)(flat_input) if self.with_logits else flat_input
# [N, ] --> [N, C]
flat_target = F.one_hot(target, num_classes=logits_size).float() \
if self.index_label_position else target.float()
if mask is not None:
mask = mask.float()
flat_input = flat_input * mask
flat_target = flat_target * mask
else:
mask = torch.ones_like(target)
loss = None
        if self.ohem_ratio > 0:
            mask_neg = torch.logical_not(mask)
            for label_idx in range(logits_size):
                # logits_size: number of classes
                # pos_example: samples whose gold label is label_idx;
                # neg_example: samples whose gold label is not label_idx
                pos_example = target == label_idx
                neg_example = target != label_idx
                pos_num = pos_example.sum()
                neg_num = mask.sum() - (pos_num - (mask_neg & pos_example).sum())
                keep_num = min(int(pos_num * self.ohem_ratio / logits_size), neg_num)
                if keep_num > 0:
                    # masked_select returns a 1-D tensor holding the entries of
                    # flat_input where the mask is true: the negative samples'
                    # scores, reshaped back to [num_neg, C]
                    neg_scores = torch.masked_select(
                        flat_input,
                        neg_example.view(-1, 1).bool()
                    ).view(-1, logits_size)
                    neg_scores_idx = neg_scores[:, label_idx]
                    neg_scores_sort, _ = torch.sort(neg_scores_idx)
                    threshold = neg_scores_sort[-keep_num + 1]
                    # keep a sample if it is predicted as this label with a score
                    # at or above the threshold, or if its gold label is this label
                    cond = ((torch.argmax(flat_input, dim=1) == label_idx)
                            & (flat_input[:, label_idx] >= threshold)) | pos_example.view(-1)
                    ohem_mask_idx = torch.where(cond, 1, 0)
                    flat_input_idx = flat_input[:, label_idx]
                    flat_target_idx = flat_target[:, label_idx]
                    flat_input_idx = flat_input_idx * ohem_mask_idx
                    flat_target_idx = flat_target_idx * ohem_mask_idx
                else:
                    flat_input_idx = flat_input[:, label_idx]
                    flat_target_idx = flat_target[:, label_idx]
                loss_idx = self._compute_dice_loss(
                    flat_input_idx.view(-1, 1),
                    flat_target_idx.view(-1, 1)
                )
                if loss is None:
                    loss = loss_idx
                else:
                    loss += loss_idx
            return loss
        else:
            for label_idx in range(logits_size):
                flat_input_idx = flat_input[:, label_idx]    # predicted probability for this label
                flat_target_idx = flat_target[:, label_idx]  # 0/1 targets for this label, [N, ]
                loss_idx = self._compute_dice_loss(flat_input_idx.view(-1, 1), flat_target_idx.view(-1, 1))
                if loss is None:
                    loss = loss_idx
                else:
                    loss += loss_idx
            return loss
    def _binary_class(self, input, target, mask=None):
        flat_input = input.view(-1)
        flat_target = target.view(-1).float()
        flat_input = torch.sigmoid(flat_input) if self.with_logits else flat_input
        if mask is not None:
            mask = mask.float()
            flat_input = flat_input * mask
            flat_target = flat_target * mask
        else:
            mask = torch.ones_like(target)
        if self.ohem_ratio > 0:
            pos_example = target > 0.5
            neg_example = target <= 0.5
            mask_neg_num = mask <= 0.5
            # count positives that are not masked out
            pos_num = pos_example.sum() - (pos_example & mask_neg_num).sum()
            neg_num = neg_example.sum()
            keep_num = min(int(pos_num * self.ohem_ratio), neg_num)
            # keep only the hardest (highest-scoring) negatives, plus all positives
            neg_scores = torch.masked_select(flat_input, neg_example.bool())
            neg_scores_sort, _ = torch.sort(neg_scores)
            threshold = neg_scores_sort[-keep_num + 1]
            cond = (flat_input > threshold) | pos_example.view(-1)
            ohem_mask = torch.where(cond, 1, 0)
            flat_input = flat_input * ohem_mask
            flat_target = flat_target * ohem_mask
        return self._compute_dice_loss(flat_input, flat_target)
def __str__(self):
return f"Dice Loss smooth:{self.smooth}, ohem: {self.ohem_ratio}, alpha: {self.alpha}"
def __repr__(self):
return str(self)
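Finally, a small end-to-end sketch (toy logits and made-up shapes) exercising the multi-class path of the class above, once plain and once with the focal-style alpha:

import torch

torch.manual_seed(0)
logits = torch.randn(8, 3, requires_grad=True)  # [N, C] logits for 3 classes
targets = torch.randint(0, 3, (8,))             # [N] integer labels

# plain multi-class dice loss over softmax probabilities
loss = DiceLoss(with_logits=True, smooth=1e-4)(logits, targets)
loss.backward()
print(loss)

# alpha > 0 additionally down-weights easy, well-classified examples
loss_dsc = DiceLoss(with_logits=True, alpha=1.0)(logits, targets)
print(loss_dsc)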