Implementing nn.CrossEntropyLoss and multi-class focal loss in PyTorch

This post compares PyTorch's built-in nn.CrossEntropyLoss with a focal loss implementation, provides custom functions for both, and shows how to use per-class weights and the different reduction strategies. Code examples demonstrate the performance and usage of the three loss functions.
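
For reference, the formulas implemented below are the standard ones (the original post does not spell them out); the notation here is my own. For one sample with logits z, target class t, softmax probability p_t, and per-class weight w_t:

\mathrm{CE}(t) = -\, w_t \log p_t, \qquad p_t = \frac{\exp(z_t)}{\sum_{c=1}^{C} \exp(z_c)}

\mathrm{FL}(t) = -\, w_t \,(1 - p_t)^{\gamma}\, \log p_t

Setting \gamma = 0 reduces the focal loss to the weighted cross entropy. With reduction='mean', nn.CrossEntropyLoss divides the summed loss by the sum of the sample weights, which is why both custom functions below normalize by weight[targets].sum().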

I'm not going to write a detailed walkthrough of nn.CrossEntropyLoss and focal loss here, since both are fairly easy to understand; here is the code directly:

import time
import torch
import torch.nn as nn
import torch.nn.functional as F

# On how nn.CrossEntropyLoss applies the per-class weight argument, see:
# https://discuss.pytorch.org/t/passing-the-weights-to-crossentropyloss-correctly/14731/10?u=simiao_lai1

def crossentropy(logits, targets, weight=None, reduction='mean'):
    """N samples, C classes
    logits: [N, C]
    targets:[N] range [0, C-1]
    weight: [C]
    """
    C = logits.size(1)
    if weight is not None:
        assert len(weight) == C, 'weight length must be equal to classes number'
        assert weight.dim() == 1, 'weight dim must be 1'
    else:
        # default: uniform class weights, created on the same device/dtype as the logits
        weight = torch.ones(C, device=logits.device, dtype=logits.dtype)

    log_prob = F.log_softmax(logits, dim=1)
    tar_one_hot = F.one_hot(targets, num_classes=C).type(torch.float32)

    # per-sample negative log-likelihood, scaled by the weight of each sample's class
    loss = weight[targets] * (-log_prob * tar_one_hot).sum(dim=1)
    if reduction == 'mean':
        # match nn.CrossEntropyLoss: normalize by the sum of the sample weights
        loss = loss.sum() / (weight[targets].sum() + 1e-7)
    elif reduction == 'sum':
        loss = loss.sum()
    # reduction == 'none': keep the per-sample losses

    return loss

def focalloss(logits, targets, gamma=0.0, weight=None, reduction='mean'):
    """N samples, C classes
    logits: [N, C]
    targets:[N] range [0, C-1]
    gamma: factor(default 0.0, that is standard cross entropy)
    weight: [C]
    """
    N, C = logits.size(0), logits.size(1)
    if weight is not None:
        assert len(weight) == C, 'weight length must be equal to classes number'
        assert weight.dim() == 1, 'weight dim must be 1'
    else:
        # default: uniform class weights, created on the same device/dtype as the logits
        weight = torch.ones(C, device=logits.device, dtype=logits.dtype)

    prob = F.softmax(logits, dim=1)
    log_prob = F.log_softmax(logits, dim=1)
    tar_one_hot = F.one_hot(targets, num_classes=C).type(torch.float32)
    # modulating factor (1 - p_t)^gamma down-weights well-classified samples
    factor = (1 - prob[range(N), targets]) ** gamma

    loss = factor * weight[targets] * (-log_prob * tar_one_hot).sum(dim=1)
    if reduction == 'mean':
        # match nn.CrossEntropyLoss: normalize by the sum of the sample weights
        loss = loss.sum() / (weight[targets].sum() + 1e-7)
    elif reduction == 'sum':
        loss = loss.sum()
    # reduction == 'none': keep the per-sample losses

    return loss
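
# A quick illustration of the gamma term (my own sketch, not from the original post;
# the logits below are made up): the more confident the model is about the correct
# class, the more a larger gamma suppresses that sample's loss.
confident_logit = torch.tensor([[6.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
confident_target = torch.tensor([0])
for g in (0.0, 0.5, 2.0):
    # (1 - p_t)^gamma -> 0 as p_t -> 1, so the loss shrinks as gamma grows
    print(g, focalloss(confident_logit, confident_target, gamma=g).item())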


# random demo data: N samples, C classes, plus random per-class weights
N, C = 3, 6
logit = torch.randn(N, C)
target = torch.randint(0, C, (N,))
weight = torch.rand(C)

# compare the built-in loss against the two custom implementations, timing each call
criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')
s = time.time()
loss1 = criterion(logit, target)
print(time.time()-s)
s = time.time()
loss2 = crossentropy(logit, target, weight=weight, reduction='none')
print(time.time()-s)
s = time.time()
loss3 = focalloss(logit, target, gamma=0.5, weight=weight, reduction='none')
print(time.time()-s)

# loss1 and loss2 should match exactly; loss3 differs because gamma=0.5 down-weights easy samples
print(loss1)
print(loss2)
print(loss3)
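
As a quick sanity check (my own addition, not part of the original post): with reduction='mean' the custom crossentropy should match nn.CrossEntropyLoss up to the 1e-7 stabilizer in the denominator, and focalloss with gamma=0 should reproduce the weighted cross entropy exactly.

criterion_mean = nn.CrossEntropyLoss(weight=weight, reduction='mean')
print(torch.allclose(criterion_mean(logit, target),
                     crossentropy(logit, target, weight=weight, reduction='mean')))
print(torch.allclose(crossentropy(logit, target, weight=weight, reduction='none'),
                     focalloss(logit, target, gamma=0.0, weight=weight, reduction='none')))

Both checks should print True.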