pytorch weighted_and_neg_topk_cross_entropy 加权的负权重topk交叉熵损失

该文介绍了一种用于GPT模型训练的改进损失函数,结合加权和数据增强。新方法在目标类别权重为负且未出现在预测的前K个高概率类别时忽略梯度计算,以适应NLG任务。此优化可能导致非单调的Loss曲线,但能提高模型性能。
摘要由CSDN通过智能技术生成

根据这段时间的NLG经验,继续改进损失函数。

主要用于以下文章所写的 NLP 增强管道。

一种 用于GPT模型 训练的 包含加权 和 数据增强 和 损失方法 的设计
https://blog.csdn.net/ONE_SIX_MIX/article/details/129682576

相比上面文章里改的的loss,加入topk 负类型测试,当负权重的类别在预测类别前 K 的高概率类别时,才会传递梯度,否则会跳过

import torch
import torch.nn.functional as F
from typing import Optional


@torch.jit.script
def weighted_and_neg_topk_cross_entropy(
    input: torch.Tensor,
    target: torch.Tensor,
    topk: Optional[int]=None,
    target_weight: Optional[torch.Tensor]=None,
    target_mask: Optional[torch.Tensor]=None,
    label_smoothing: float=0.,
    ignore_zero_target_weight: bool=True,
):
    '''
    加权的负权重topk交叉熵损失,主要用于NLG任务,对基于Sample的生成方式比较有效。
    主要行为:
    如果 目标项 对应的 target_weight 权重大于0,按照正常的交叉熵来计算
    如果 目标项 对应的 target_weight 权重小于0,先检查 目标项 的类别是否在 预测项前topk个高概率预测中,如果在,则按正常交叉熵来计算,如果不在,则跳过该项的计算。
    行为目的:
    忽略负向权重已掉出前topk的预测类别的梯度计算

    注意:
    如果使用了负向权重,在模型性能越好时,Loss值并非是单调下降的,可能会上升。并且Loss值可以小于0,然后Loss最小时并不是最优(训练最优)模型。如果需要用于评估,需要结合其他指标来评估。
    例如,可以使用多数为负值的 target_weight,可以发现 Loss 值是负的,然后在收敛后期时,Loss会反弹到0值处。

    以下维度缩写,B 代表批量大小,C 代表词向量维度
    虽然写着形状是 [B,C,...] 和 [B,...]
    :param input:                       FloatTensor shape [B,C,...] , 模型的输出
    :param target:                      LongTensor shape [B,...] , 预测目标
    :param topk:                        int or None , 检查前k个预测,None代表不使用,推荐使用10
    :param target_weight:               FloatTensor shape [B,...] or None , 每个目标的权值
    :param target_mask:                 BoolTensor shape [B,...] or None , 目标的掩码,True代表参与计算,False代表忽略
    :param label_smoothing:             float , 标签平滑
    :param ignore_zero_target_weight:   bool, 是否忽略 target_weight 中为0的目标,使其不参与梯度计算
    :return:
    '''
    assert target.shape[0] == input.shape[0] and target.shape[1:] == input.shape[2:], 'Error! Bad input and target shape.'
    assert topk is None or 0 < topk <= input.shape[1], 'Error! Bad param topk.'
    assert target_weight is None or target.shape == target_weight.shape, 'Error! Bad target_weight shape.'
    assert target_mask is None or (target.shape == target_mask.shape and target_mask.dtype == torch.bool), 'Error! Bad target_mask shape or dtype.'

    loss = F.cross_entropy(input, target, label_smoothing=label_smoothing, reduction='none')

    if target_weight is not None:
        loss = loss * target_weight

    if target_mask is None:
        target_mask = torch.full_like(target, 1, dtype=torch.bool)

    if target_weight is not None and topk is not None:
        # 如果负向权重的目标类别不在前K个列表中时,则跳过
        out_topk_cls = torch.topk(input.detach(), topk, dim=1, sorted=False)[1]
        # 筛选出 权重为负的,并且预测类别在前k个最高概率里的项
        neg_cls_slient_mask = torch.logical_and(~(target[:, None] == out_topk_cls).max(dim=1)[0], target_weight < 0)
        # 取反
        inv_neg_cls_slient_mask = ~neg_cls_slient_mask
        # 应用到 mask 上,即额外排除掉 权重为负的,并且预测类别不在前k个最高概率里的项
        target_mask = target_mask & inv_neg_cls_slient_mask

    if ignore_zero_target_weight and target_weight is not None:
        target_mask = target_mask & ~(target_weight == 0.)

    if target_mask.any().item():
        loss = loss[target_mask].mean()
    else:
        # 如果 mask 全部均为 False,代表 loss 为 0,为确保loss可以backward,所以使用 mul(0.) 处理
        loss = loss.sum().mul(0.)

    return loss


if __name__ == '__main__':
    a = torch.rand([1, 10])
    a[0, 1]+=5
    t = torch.zeros([1],dtype=torch.long) + 1

    a.requires_grad = True
    optim = torch.optim.Adam([a],lr=1e-2)

    for i in range(1000):
        optim.zero_grad()
        loss = weighted_and_neg_topk_cross_entropy(a, t, 9, torch.as_tensor([-0.1]), torch.as_tensor([True]), 0)
        loss.backward()
        optim.step()
        print(loss, a.tolist())

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值