torch交叉熵计算

本文介绍了在PyTorch中如何实现交叉熵损失函数,包括两种不同的代码实现方式,并详细解释了mask的作用,以及.size()和.contiguous().view(-1, out_size)在代码中的功能。" 131858346,9311708,Matlab伪谱法模拟地震波正演:代码与解析,"['matlab', '地震波模拟', '信号处理', '数值计算']
摘要由CSDN通过智能技术生成

交叉熵计算函数

第一种代码

import torch as t
import torch.nn  as nn
#  batch_size=3,计算对应每个类别的分数(只有两个类别)
score = t.randn(1, 4)
# 三个样本分别属于1,0,1类,label必须是LongTensor
label = t.Tensor([1]).long()

# loss与普通的layer无差异
criterion = nn.CrossEntropyLoss()
loss = criterion(score, label)
loss

第二种代码

def cal_loss(logits, targets, tag2id):
    """计算损失
    参数:
        logits: [B, L, out_size]  out_size为不同类别的估计值
        targets: [B, L]
    首先把target经过mask后展平为一维tensor,然后同样的对logit做mask,view后为(-1,out_size)的size,做cross_entropy
    可以考虑加value weight
    """
    PAD = tag2id.get('<pad>')
    assert PAD is not None

    mask = (targets != PAD)  # [B, L]
    targets = targets[mask]  #变成一维
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值