【无标题】

手动实现torch.nn.CrossEntropyLoss()

损失函数手动实现

def cross_entropy_loss(y_pred , y_true):
    one_hot = F.one_hot(y_true).float()  # 对真实labels进行one_hot编码
    softmax = torch.exp(y_pred) / torch.sum(torch.exp(y_pred), dim=1).reshape(-1, 1)#将预测结果进行softmax
    logsoftmax = torch.log(softmax)
    nllloss = -torch.sum(one_hot * logsoftmax) / y_true.shape[0]
    
    return nllloss

测试手写损失函数的有效性

相同相同预测张量与真实值输入手写损失函数与nn.CrossEntropyLoss(),输出结果一致,则证明手写损失函数有效

假设该分类任务有7个类别
logits表示预测输出,尺寸为[batch_size,num_labels],即[7,7]
labels表示真实标签,尺寸为[batch_size],即[7]

注意:因为总类别数(num_labels)为7,所以labels必须包含0-6所有,否则one_hot编码后尺寸不对,one_hot * logsoftmax运算会报错

logits = torch.tensor([[-0.1378,  0.3560, -0.1881, -0.1667, -0.1741,  0.3571, -0.2159],
                       [-0.1421,  0.3805, -0.1834, -0.1769, -0.2232,  0.3823, -0.2155],
                       [-0.0975,  0.4085, -0.1694, -0.1913, -0.2099,  0.3725, -0.2611],
                       [-0.1729,  0.4593, -0.1960, -0.2224, -0.2111,  0.4147, -0.3108],
                       [-0.1913,  0.4942, -0.2568, -0.1854, -0.2764,  0.4796, -0.3428],
                       [-0.4058,  0.9523, -0.5067, -0.4886, -0.5095,  0.7811, -0.5855],
                       [-0.3806,  1.1320, -0.5704, -0.5436, -0.6376,  0.9648, -0.6508]
                       ])
labels=torch.tensor([0,1,2,3,4,5,6])
loss_value = cross_entropy_loss(logits,labels)
print("自定义函数实现",loss_value)

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)
print("函数实现",loss)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值