# Re-implementation of pytorch's torch.nn.CrossEntropyLoss.
# The key step is converting the target to one-hot form: [minibatch] -> [minibatch, C].
import torch
import torch.nn as nn
import torch.nn.functional as f
class CrossEntropy:
    """Re-implementation of ``torch.nn.CrossEntropyLoss``.

    The integer class targets of shape ``[minibatch]`` are expanded into a
    one-hot (optionally label-smoothed) distribution of shape
    ``[minibatch, C]``, and the loss is the weighted negative log-likelihood
    under that distribution.  Matches ``torch.nn.functional.cross_entropy``
    for ``weight``, ``ignore_index``, ``label_smoothing`` and the
    ``'mean'``/``'sum'``/``'none'`` reductions.
    """

    def __init__(self, weight=None, reduction='mean', ignore_index=-100, label_smoothing=None):
        # weight: optional per-class rescaling tensor of shape (C,).
        # reduction: 'mean' | 'sum' | 'none'; any other value falls back to
        #            'sum', preserving the original implementation's behavior.
        # ignore_index: target value whose samples contribute nothing.
        # label_smoothing: smoothing factor in [0, 1); None means no smoothing.
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        self.weight = weight

    def __call__(self, input, target):
        """Compute the cross-entropy loss.

        Args:
            input: raw (unnormalized) logits of shape (N, C).
            target: integer class indices of shape (N,); entries equal to
                ``ignore_index`` are excluded from the loss.

        Returns:
            A scalar tensor for 'mean'/'sum' reduction, or a shape-(N,)
            tensor for 'none'.
        """
        num_classes = input.size(-1)
        # Samples that actually contribute to the loss.
        valid = (target != self.ignore_index)
        # Replace ignored targets with class 0 so scatter/indexing stays in
        # range (the raw ignore_index, e.g. -100, is an invalid scatter
        # index); their contribution is zeroed out via `valid` below.
        safe_target = torch.where(valid, target, torch.zeros_like(target))

        weight = self.weight
        if weight is None:
            # NOTE: `if not tensor` is ambiguous for multi-element tensors,
            # so the presence check must be `is None`, not truthiness.
            weight = torch.ones(num_classes, dtype=input.dtype)

        # log_softmax is numerically stabler than log(softmax(x)).
        log_prob = torch.log_softmax(input, dim=-1)

        one_hot = torch.zeros_like(input).scatter(-1, safe_target.unsqueeze(-1), 1)
        eps = self.label_smoothing or 0.0
        if eps:
            # PyTorch's smoothing: q = (1 - eps) * one_hot + eps / C, i.e.
            # the true class receives 1 - eps + eps/C (not 1 - eps).
            soft_target = one_hot * (1.0 - eps) + eps / num_classes
        else:
            soft_target = one_hot

        # Per-sample loss: -sum_c w_c * q_c * log p_c, zeroed for ignored rows.
        mask = valid.to(input.dtype)
        loss = -(soft_target * log_prob * weight).sum(-1) * mask

        if self.reduction == 'mean':
            # PyTorch normalizes by the summed weight of each sample's true
            # class over the non-ignored samples.
            denom = (weight[safe_target] * mask).sum()
            return loss.sum() / denom
        if self.reduction == 'none':
            return loss
        return loss.sum()
# Demo: compare the re-implementation against torch's reference implementation.
torch.manual_seed(0)  # seed so the printed comparison is reproducible
cross_entropy = CrossEntropy(label_smoothing=0.3, reduction='sum')
# `dtype=torch.float64` is the modern spelling of `.type(torch.DoubleTensor)`;
# also avoid shadowing the `input` builtin.
logits = torch.randn(10, 4, dtype=torch.float64)
labels = torch.randint(0, 4, (10,))
print(cross_entropy(logits, labels))
print(f.cross_entropy(logits, labels, label_smoothing=0.3, reduction='sum'))