Reimplementing PyTorch's torch.nn.CrossEntropyLoss

The key step is converting `target` to one-hot form, i.e. [minibatch] -> [minibatch, C].
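As a minimal sketch of that conversion, `scatter` can expand a vector of class indices into one-hot rows (the tensor values here are illustrative):

```python
import torch

# Class indices for a minibatch of 3 samples, C = 4 classes
target = torch.tensor([2, 0, 3])
# Expand [minibatch] -> [minibatch, C] by scattering 1s at the target positions
one_hot = torch.zeros(3, 4).scatter(-1, target.unsqueeze(-1), 1.0)
print(one_hot)
```

Each row sums to 1, with the 1 sitting at the row's class index.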

import torch
import torch.nn as nn
import torch.nn.functional as f


class CrossEntropy:
    def __init__(self, weight=None, reduction='mean', ignore_index=-100, label_smoothing=None):
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        self.weight = weight

    def __call__(self, input, target):
        num_classes = input.size(-1)
        # 1 for positions to keep, 0 where target equals ignore_index
        mask = torch.where(target == self.ignore_index, 0, 1)
        # `not tensor` is ambiguous for multi-element tensors; compare against None instead
        weight = torch.ones(num_classes, dtype=input.dtype) if self.weight is None else self.weight
        # Clamp ignored targets to a valid index before scatter; their loss is zeroed by `mask`
        target = target.masked_fill(target == self.ignore_index, 0)
        # One-hot encode: [minibatch] -> [minibatch, C]
        target = torch.zeros_like(input).scatter(-1, target.unsqueeze(-1), 1)
        # log_softmax is numerically more stable than log(softmax(...))
        input = -torch.log_softmax(input, dim=-1)

        if self.label_smoothing:
            # Match F.cross_entropy: mix the one-hot target with a uniform distribution,
            # (1 - eps) * one_hot + eps / C; note the smoothing is applied only once
            target = target * (1 - self.label_smoothing) + self.label_smoothing / num_classes

        L = (target * input * weight).sum(-1) * mask
        if self.reduction == 'mean':
            # Divide by the total weight of the non-ignored samples
            return L.sum() / ((target * weight).sum(-1) * mask).sum()
        return L.sum()
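The smoothed target distribution used by `F.cross_entropy` mixes the one-hot vector with a *uniform* distribution over all C classes (eps / C per class), not eps / (C - 1) over the wrong classes only. A quick worked example with eps = 0.3, C = 4:

```python
import torch

eps, C = 0.3, 4
one_hot = torch.tensor([0., 1., 0., 0.])
# True class: 1 - eps + eps/C = 0.775; each other class: eps/C = 0.075
smoothed = one_hot * (1 - eps) + eps / C
print(smoothed)
```

The smoothed vector still sums to 1, so it remains a valid probability distribution.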


# Sanity check against the built-in implementation; the two printed values should match
cross_entropy = CrossEntropy(label_smoothing=0.3, reduction='sum')
input = torch.randn(10, 4).double()
target = torch.randint(0, 4, (10,))
print(cross_entropy(input, target))
print(f.cross_entropy(input, target, label_smoothing=0.3, reduction='sum'))
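The `ignore_index` path can be spot-checked the same way. This self-contained sketch computes the masked loss by hand (gathering the target log-probabilities and averaging over the non-ignored count, which is what `mean` reduction does when `ignore_index` is set) and compares it with the built-in:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
input = torch.randn(6, 4).double()
target = torch.tensor([0, 1, -100, 2, -100, 3])  # -100 rows are ignored

# Manual computation: zero out ignored rows, average over the non-ignored count
mask = (target != -100)
safe_target = target.masked_fill(~mask, 0)      # valid indices for gather
log_p = -torch.log_softmax(input, dim=-1)
per_sample = log_p.gather(-1, safe_target.unsqueeze(-1)).squeeze(-1) * mask
manual = per_sample.sum() / mask.sum()

builtin = F.cross_entropy(input, target, ignore_index=-100)
print(torch.allclose(manual, builtin))
```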
