现在不太想写关于 nn.CrossEntropyLoss 和 focal loss 的解析，两者都比较容易理解，这里直接贴上代码：
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
# Reference for how the per-class `weight` enters the weighted mean:
# https://discuss.pytorch.org/t/passing-the-weights-to-crossentropyloss-correctly/14731/10?u=simiao_lai1
def crossentropy(logits, targets, weight=None, reduction='mean'):
    """Weighted cross-entropy, a re-implementation of ``nn.CrossEntropyLoss``.

    N samples, C classes.

    Args:
        logits: [N, C] unnormalized class scores.
        targets: [N] integer class indices in range [0, C-1].
        weight: optional [C] per-class weights; defaults to all ones.
        reduction: 'mean' (sum of weighted per-sample losses divided by the
            sum of the target classes' weights), 'sum', or 'none'
            (per-sample losses, shape [N]).

    Returns:
        A scalar tensor for 'mean' / 'sum', or an [N] tensor for 'none'.

    Raises:
        AssertionError: if ``weight`` is not a 1-D tensor of length C.
        ValueError: if ``reduction`` is not 'mean', 'sum' or 'none'.
    """
    C = logits.size(1)
    if weight is not None:
        # Check the rank first: len() of a 0-d tensor raises TypeError.
        assert weight.dim() == 1, 'weight dim must be 1'
        assert len(weight) == C, 'weight length must be equal to classes number'
    else:
        # Follow the logits' dtype/device so CUDA inputs work as well.
        weight = torch.ones(C, dtype=logits.dtype, device=logits.device)
    log_prob = F.log_softmax(logits, dim=1)
    # Pick the target log-probability directly. A one-hot product would give
    # NaN (inf * 0) whenever a non-target logit is -inf, e.g. masked classes.
    nll = -log_prob.gather(1, targets.unsqueeze(1)).squeeze(1)
    sample_weight = weight[targets]
    loss = sample_weight * nll
    if reduction == 'mean':
        # 1e-7 guards against division by zero when every target weight is 0.
        return loss.sum() / (sample_weight.sum() + 1e-7)
    if reduction == 'sum':
        return loss.sum()
    if reduction == 'none':
        return loss
    raise ValueError(
        f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}")
def focalloss(logits, targets, gamma=0.0, weight=None, reduction='mean'):
    """Weighted focal loss: cross-entropy scaled by (1 - p_t) ** gamma.

    N samples, C classes. With ``gamma=0.0`` this equals the weighted
    cross-entropy of ``nn.CrossEntropyLoss``.

    Args:
        logits: [N, C] unnormalized class scores.
        targets: [N] integer class indices in range [0, C-1].
        gamma: focusing factor (default 0.0, i.e. standard cross-entropy).
        weight: optional [C] per-class weights; defaults to all ones.
        reduction: 'mean' (sum of per-sample losses divided by the sum of the
            target classes' weights), 'sum', or 'none' (shape [N]).

    Returns:
        A scalar tensor for 'mean' / 'sum', or an [N] tensor for 'none'.

    Raises:
        AssertionError: if ``weight`` is not a 1-D tensor of length C.
        ValueError: if ``reduction`` is not 'mean', 'sum' or 'none'.
    """
    C = logits.size(1)
    if weight is not None:
        # Check the rank first: len() of a 0-d tensor raises TypeError.
        assert weight.dim() == 1, 'weight dim must be 1'
        assert len(weight) == C, 'weight length must be equal to classes number'
    else:
        # Follow the logits' dtype/device so CUDA inputs work as well.
        weight = torch.ones(C, dtype=logits.dtype, device=logits.device)
    log_prob = F.log_softmax(logits, dim=1)
    # log p_t of each sample's target class, shape [N]. gather avoids the
    # NaN (inf * 0) a one-hot product produces when a non-target logit is -inf.
    log_pt = log_prob.gather(1, targets.unsqueeze(1)).squeeze(1)
    # p_t from log p_t: one softmax pass instead of two.
    factor = (1 - log_pt.exp()) ** gamma
    sample_weight = weight[targets]
    loss = -factor * sample_weight * log_pt
    if reduction == 'mean':
        # Normalized by the target weights only (not by the focal factor), so
        # gamma=0 reproduces the weighted cross-entropy mean. 1e-7 avoids /0.
        return loss.sum() / (sample_weight.sum() + 1e-7)
    if reduction == 'sum':
        return loss.sum()
    if reduction == 'none':
        return loss
    raise ValueError(
        f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}")
# Demo: compare the hand-written losses against nn.CrossEntropyLoss.
# time.perf_counter() is monotonic and has the highest available resolution,
# unlike time.time(), which matters when timing sub-millisecond calls.
N, C = 3, 6
logit = torch.randn(N, C)
target = torch.randint(0, C, (N,))
weight = torch.rand(C)
criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')
s = time.perf_counter()
loss1 = criterion(logit, target)
print(time.perf_counter() - s)
s = time.perf_counter()
loss2 = crossentropy(logit, target, weight=weight, reduction='none')
print(time.perf_counter() - s)
s = time.perf_counter()
loss3 = focalloss(logit, target, gamma=0.5, weight=weight, reduction='none')
print(time.perf_counter() - s)
print(loss1)
print(loss2)
print(loss3)
# Sanity check: the re-implementation should agree with PyTorch's loss.
print(torch.allclose(loss1, loss2, atol=1e-6))