https://blog.csdn.net/CaiDaoqing/article/details/90457197
为防止网页关闭,所做的笔记。
Dice loss
![e8e80472d31011aa8ed85b32d036b3b1.png](https://i-blog.csdnimg.cn/blog_migrate/bea4a81608bec66d27a3c7b4ff67d157.png)
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
    """Soft Dice loss averaged over the batch.

    For every sample the Dice coefficient

        (2 * sum(input * target) + smooth) / (sum(input) + sum(target) + smooth)

    is computed over all non-batch elements, and the loss is one minus the
    batch mean of that coefficient. The smoothing term is added once to the
    numerator and once to the denominator so that the coefficient stays
    at most 1 (a perfect match yields a loss of exactly 0).

    No activation is applied here; ``input`` is presumably already a
    probability map (e.g. sigmoid output) -- confirm against the caller.

    Args:
        smooth: Additive smoothing constant that keeps the ratio defined when
            both ``input`` and ``target`` are all zeros. Defaults to 1.
    """

    def __init__(self, smooth=1):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, input, target):
        """Return the scalar Dice loss for a batch.

        Args:
            input: Predictions; the first dimension is the batch dimension.
            target: Ground truth with the same number of elements per sample
                as ``input``.

        Returns:
            A 0-dim tensor: ``1 - mean(dice coefficient per sample)``.
        """
        N = target.size(0)

        # reshape (rather than view) also accepts non-contiguous slices.
        input_flat = input.reshape(N, -1)
        target_flat = target.reshape(N, -1)

        intersection = (input_flat * target_flat).sum(1)
        denominator = input_flat.sum(1) + target_flat.sum(1) + self.smooth
        dice = (2 * intersection + self.smooth) / denominator

        return 1 - dice.sum() / N
class MulticlassDiceLoss(nn.Module):
    """Sum of per-class Dice losses for a one-hot encoded target.

    ``input`` and ``target`` are both indexed as ``[:, i]`` for every class
    ``i``, so their first two dimensions must be (N, C): batch size and
    number of classes. ``DiceLoss`` is applied to each class slice and the
    (optionally weighted) results are added up -- the sum is not averaged
    over classes.
    """

    def __init__(self):
        super(MulticlassDiceLoss, self).__init__()

    def forward(self, input, target, weights=None):
        """Return the accumulated Dice loss over all classes.

        Args:
            input: Predictions whose dimension 1 indexes classes.
            target: One-hot ground truth whose dimension 1 indexes classes.
            weights: Optional per-class multipliers, indexed by class id.
                When ``None`` every class contributes with weight 1.

        Returns:
            The sum of the per-class Dice losses (the integer 0 when the
            target has no classes).
        """
        num_classes = target.shape[1]
        binary_dice = DiceLoss()
        use_weights = weights is not None

        total = 0
        for cls in range(num_classes):
            class_loss = binary_dice(input[:, cls], target[:, cls])
            if use_weights:
                class_loss *= weights[cls]
            total += class_loss
        return total
IOU loss
![3f81db40fcafe42482ab56b1e94812d3.png](https://i-blog.csdnimg.cn/blog_migrate/e6bee979f3e451977f0a5ec355404e41.png)
#针对多分类问题,二分类问题更简单一点。
import torch
import torch.nn as nn
import torch.nn.functional as F
class SoftIoULoss(nn.Module):
    """Soft (differentiable) IoU loss for multi-class segmentation.

    The logits are turned into probabilities with a softmax over the class
    dimension, the integer target is one-hot encoded, and the soft
    intersection-over-union is computed per sample and per class. The loss
    is the negative mean IoU, so lower (closer to -1) is better.

    Args:
        n_classes: Number of classes, i.e. the size of dimension 1 of the
            logits passed to ``forward``.
    """

    def __init__(self, n_classes):
        super(SoftIoULoss, self).__init__()
        self.n_classes = n_classes

    @staticmethod
    def to_one_hot(tensor, n_classes):
        """One-hot encode an (N, H, W) label map into (N, n_classes, H, W).

        The result is created on the same device as ``tensor`` so that CUDA
        targets do not trigger a device mismatch, and the labels are cast to
        int64 because ``scatter_`` only accepts int64 indices.
        """
        n, h, w = tensor.size()
        index = tensor.long().unsqueeze(1)  # N x 1 x H x W
        one_hot = torch.zeros(n, n_classes, h, w, device=tensor.device)
        return one_hot.scatter_(1, index, 1)

    def forward(self, input, target):
        """Return the negative mean soft IoU.

        Args:
            input: Logits of shape N x Classes x H x W.
            target: Integer class labels of shape N x H x W.

        Returns:
            A 0-dim tensor holding ``-mean(IoU)`` over batch and classes.
        """
        N = len(input)

        pred = F.softmax(input, dim=1)
        target_onehot = self.to_one_hot(target, self.n_classes)

        # Soft intersection; reused below in the union term.
        overlap = pred * target_onehot
        # Sum over all pixels: N x C x H x W => N x C
        inter = overlap.view(N, self.n_classes, -1).sum(2)

        # Soft union: |A| + |B| - |A intersect B|
        union = pred + target_onehot - overlap
        # Sum over all pixels: N x C x H x W => N x C
        union = union.view(N, self.n_classes, -1).sum(2)

        # The epsilon guards against division by zero for an empty union.
        iou = inter / (union + 1e-16)

        # Negated so that maximising IoU minimises the loss.
        return -iou.mean()
Lovasz-Softmax loss
论文:The Lovász-Softmax loss: A tractable surrogate for the optimization of the intersection-over-union measure in neural networks https://arxiv.org/abs/1705.08790
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
def lovasz_grad(gt_sorted):
    """
    Gradient of the Lovasz extension with respect to sorted errors
    (Alg. 1 in the paper).

    gt_sorted: 1-D tensor of ground-truth values, already ordered by
        decreasing error.
    Returns a float tensor of the same length holding the per-position
    increments of the Jaccard loss.
    """
    n_pixels = len(gt_sorted)
    total_fg = gt_sorted.sum()

    # Running intersection / union as more and more sorted pixels are included.
    remaining_fg = total_fg - gt_sorted.float().cumsum(0)
    covered = total_fg + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - remaining_fg / covered

    if n_pixels > 1:
        # Turn cumulative Jaccard values into successive differences; the
        # right-hand side is materialised first so the overlapping slices
        # do not interfere. A single pixel has nothing to difference against.
        steps = jaccard[1:n_pixels] - jaccard[0:-1]
        jaccard[1:n_pixels] = steps
    return jaccard
# --------------------------- BINARY LOSSES ---------------------------
def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
      logits: [B, H, W] tensor, logits at each pixel (between -infty and +infty)
      labels: [B, H, W] tensor, binary ground truth masks (0 or 1)
      per_image: compute the loss per image and average, instead of once
          over the whole batch
      ignore: void class id; pixels with this label are dropped

    Returns a scalar tensor.
    """
    if not per_image:
        return lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))

    # Averaged explicitly: no `mean` helper is defined in this file, so the
    # per-image path must not rely on one.
    per_image_losses = [
        lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
        for log, lab in zip(logits, labels)
    ]
    if not per_image_losses:
        # Empty batch: a zero that is still connected to `logits`.
        return logits.sum() * 0.
    return torch.stack(per_image_losses).mean()
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss on flattened predictions
      logits: [P] tensor, logits at each prediction (between -infty and +infty)
      labels: [P] tensor, binary ground truth labels (0 or 1)

    Returns a scalar tensor.
    """
    if len(labels) > 0:
        # Map {0, 1} labels onto {-1, +1}; the hinge error is 1 - logit * sign.
        pm_one = Variable(2. * labels.float() - 1.)
        hinge_errors = 1. - logits * pm_one

        # Largest errors first, with the labels reordered to match.
        sorted_errors, order = torch.sort(hinge_errors, dim=0, descending=True)
        sorted_gt = labels[order.data]

        weights = Variable(lovasz_grad(sorted_gt))
        return torch.dot(F.relu(sorted_errors), weights)

    # Only void pixels: return a zero that keeps the gradients at 0.
    return logits.sum() * 0.
def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flatten the batch of predictions and labels to 1-D (binary case) and,
    when `ignore` is given, drop every position whose label equals it.

    Returns the pair (flat scores, flat labels).
    """
    flat_scores = scores.view(-1)
    flat_labels = labels.view(-1)
    if ignore is not None:
        keep = (flat_labels != ignore)
        flat_scores = flat_scores[keep]
        flat_labels = flat_labels[keep]
    return flat_scores, flat_labels
# --------------------------- MULTICLASS LOSSES ---------------------------
def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss
      probas: [B, C, H, W] tensor, class probabilities at each prediction (between 0 and 1).
              A [B, H, W] tensor is interpreted as binary (sigmoid) output.
      labels: [B, H, W] tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
      per_image: compute the loss per image and average, instead of once
          over the whole batch
      ignore: void class label; pixels with this label are dropped

    Returns a scalar tensor.
    """
    if not per_image:
        return lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)

    # Averaged explicitly: no `mean` helper is defined in this file, so the
    # per-image path must not rely on one.
    per_image_losses = [
        lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
        for prob, lab in zip(probas, labels)
    ]
    if not per_image_losses:
        # Empty batch: a zero that is still connected to `probas`.
        return probas.sum() * 0.
    # .sum() leaves a scalar loss untouched and collapses any non-scalar
    # result (e.g. an empty tensor for an all-void image) so stacking works.
    return torch.stack([loss.sum() for loss in per_image_losses]).mean()
def lovasz_softmax_flat(probas, labels, classes='present'):
    """
    Multi-class Lovasz-Softmax loss on flattened predictions
      probas: [P, C] tensor, class probabilities at each prediction (between 0 and 1)
      labels: [P] tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.

    Returns a scalar tensor: the mean of the per-class losses, or a zero
    connected to `probas` when there is nothing to average.

    Raises ValueError when C == 1 and `classes` has length > 1.
    """
    if probas.numel() == 0:
        # Only void pixels: scalar zero so the gradients are 0 and
        # .backward() still works.
        return probas.sum() * 0.

    C = probas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground mask for class c
        # Compare by value: identity comparison against a string literal is
        # implementation-dependent.
        if classes == 'present' and fg.sum() == 0:
            continue
        if C == 1:
            # NOTE: a string `classes` ('all' / 'present') also has
            # len > 1 and therefore raises here; sigmoid output needs an
            # explicit one-element class list.
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]

        errors = (fg - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        fg_sorted = fg[perm]
        # lovasz_grad works in float32; match the error dtype for torch.dot.
        grad = lovasz_grad(fg_sorted).to(errors_sorted.dtype)
        losses.append(torch.dot(errors_sorted, grad))

    if not losses:
        # No class selected or present: nothing to average.
        return probas.sum() * 0.
    # Averaged explicitly: no `mean` helper is defined in this file.
    return torch.stack(losses).mean()
def flatten_probas(probas, labels, ignore=None):
    """
    Flatten the batch of predictions to [P, C] and the labels to [P],
    where P = B * H * W.

      probas: [B, C, H, W] tensor, or [B, H, W] which is treated as a
          single-channel (sigmoid) output
      labels: [B, H, W] tensor
      ignore: label value to drop; when None nothing is removed

    Returns the pair (probas [P', C], labels [P']) with ignored positions
    removed. The probabilities stay 2-D even when exactly one (or no)
    position survives.
    """
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels

    valid = (labels != ignore)
    # Boolean-mask the rows directly: nonzero().squeeze() would collapse to
    # a 0-dim index for a single valid pixel and drop the row dimension.
    return probas[valid], labels[valid]
Focal loss
![c6ba25edf9a17dc1d22adc8194095c24.png](https://i-blog.csdnimg.cn/blog_migrate/ca941eccc3765d4ee77feafb18b51c61.png)
focal loss是交叉熵上的变种:给不同类别的样本loss加权重(即 alpha 参数)。
易分样本的 (1-pt)^gamma 会很小,这样易分样本的loss会显著减小,模型就会更关注难分样本loss的优化(即 gamma 参数)。
import torch
import torch.nn as nn
# --------------------------- BINARY LOSSES ---------------------------
class FocalLoss(nn.Module):
    """Focal loss for binary classification, built on ``BCEWithLogitsLoss``.

    The loss is ``alpha * (1 - pt) ** gamma * bce`` with ``pt = exp(-bce)``.

    NOTE: ``bce_fn`` uses its default (mean) reduction, so the focal
    modulation is applied to the batch-averaged BCE, not to each element
    individually.

    Args:
        alpha: Scalar multiplier on the loss.
        gamma: Exponent of the ``(1 - pt)`` modulating factor.
        weight: Optional rescaling weight forwarded to ``BCEWithLogitsLoss``.
        ignore_index: Label value whose positions are excluded from the
            loss; ``None`` disables the masking.
    """

    def __init__(self, alpha=0.25, gamma=2, weight=None, ignore_index=255):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index
        self.bce_fn = nn.BCEWithLogitsLoss(weight=self.weight)

    def forward(self, preds, labels):
        """Return the scalar focal loss.

        Args:
            preds: Logits.
            labels: Binary targets with the same shape as ``preds``; entries
                equal to ``ignore_index`` are dropped.
        """
        if self.ignore_index is not None:
            # The attribute set in __init__ is `ignore_index`.
            mask = labels != self.ignore_index
            labels = labels[mask]
            preds = preds[mask]

        # BCE needs floating-point targets; integer label maps are cast.
        logpt = -self.bce_fn(preds, labels.to(preds.dtype))
        pt = torch.exp(logpt)
        loss = -((1 - pt) ** self.gamma) * self.alpha * logpt
        return loss
# --------------------------- MULTICLASS LOSSES ---------------------------
class FocalLoss(nn.Module):
    """Focal loss for multi-class classification, built on ``CrossEntropyLoss``.

    The loss is ``alpha * (1 - pt) ** gamma * ce`` with ``pt = exp(-ce)``,
    where ``ce`` is whatever ``ce_fn`` returns for the batch.

    Args:
        alpha: Scalar multiplier on the loss.
        gamma: Exponent of the ``(1 - pt)`` modulating factor.
        weight: Optional per-class weight forwarded to ``CrossEntropyLoss``.
        ignore_index: Target value forwarded to ``CrossEntropyLoss`` as the
            label to ignore.
    """

    def __init__(self, alpha=0.5, gamma=2, weight=None, ignore_index=255):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma
        self.weight, self.ignore_index = weight, ignore_index
        self.ce_fn = nn.CrossEntropyLoss(
            weight=self.weight,
            ignore_index=self.ignore_index,
        )

    def forward(self, preds, labels):
        """Return the focal loss for logits ``preds`` and class-index ``labels``."""
        cross_entropy = self.ce_fn(preds, labels)
        log_pt = -cross_entropy
        pt = torch.exp(log_pt)
        modulating = (1 - pt) ** self.gamma
        return -modulating * self.alpha * log_pt
OHEM
在线困难样本挖掘,即根据loss的大小,选择有较大loss的像素反向传播,较小loss的像素梯度为0。
def focal_loss(self, output, target, alpha, gamma, OHEM_percent):
    """Binary focal loss with Online Hard Example Mining (OHEM).

    A per-element focal loss is computed from logits, and only the hardest
    (largest-loss) fraction of elements is averaged, so the remaining
    elements receive no gradient.

    Args:
        self: Unused by the body; kept for the method-style signature.
        output: Logits; flattened to 1-D.
        target: Binary targets with the same number of elements as
            ``output``; flattened to 1-D.
        alpha: Scalar multiplier on the focal loss.
        gamma: Focusing exponent.
        OHEM_percent: Fraction of elements to keep. The resulting count is
            clamped to ``[1, number of elements]`` so that a tiny fraction
            cannot select zero elements (NaN mean) and a fraction above 1
            cannot exceed what ``topk`` accepts.

    Returns:
        A 0-dim tensor: the mean focal loss over the selected elements
        (zero for empty input).
    """
    output = output.contiguous().view(-1)
    target = target.contiguous().view(-1)

    # Numerically stable binary cross-entropy with logits.
    max_val = (-output).clamp(min=0)
    bce = output - output * target + max_val + ((-max_val).exp() + (-output - max_val).exp()).log()

    # This formula gives us the log sigmoid of 1-p if y is 0 and of p if y is 1
    invprobs = F.logsigmoid(-output * (target * 2 - 1))
    per_element = alpha * (invprobs * gamma).exp() * bce

    n_elements = per_element.numel()
    if n_elements == 0:
        # Nothing to mine: a zero that is still connected to the graph.
        return per_element.sum()

    # Online Hard Example Mining: top x% losses (pixel-wise). Refer to http://www.robots.ox.ac.uk/~tvg/publications/2017/0026.pdf
    k = min(n_elements, max(1, int(OHEM_percent * n_elements)))
    hardest, _ = per_element.topk(k=k)
    return hardest.mean()