交叉熵的数学推导和手撕代码

交叉熵的数学推导和手撕代码

数学推导

在这里插入图片描述

手撕代码

import torch
import torch.nn.functional as F

# 二元交叉熵损失函数
def binary_cross_entropy(predictions, targets):
    # predictions应为sigmoid函数的输出,即概率值
    # targets应为0或1的二进制标签
    loss = -torch.mean(targets * torch.log(predictions) + (1 - targets) * torch.log(1 - predictions))
    return loss

# 多元交叉熵损失函数(使用softmax处理predictions)
def categorical_cross_entropy(predictions, targets):
    # predictions应为softmax函数的输出,即各类的概率分布
    # targets应为类别的索引(整数),通常使用torch.nn.functional.cross_entropy直接计算更为简便
    # 但为了演示,这里手动实现
    predictions = F.softmax(predictions, dim=1)  # 确保predictions是经过softmax处理的
    targets = F.one_hot(targets, num_classes=predictions.shape[1]).float()  # 将targets转换为one-hot编码
    loss = -torch.mean(torch.sum(targets * torch.log(predictions + 1e-9), dim=1))  # 加上小常数防止log(0)
    return loss

# 示例
if __name__ == "__main__":
    # 假设有10个样本,每个样本预测为二分类问题的概率
    predictions_binary = torch.randn(10, 1, requires_grad=True)
    targets_binary = torch.randint(0, 2, (10, 1)).float()
    
    # 计算二元交叉熵
    loss_binary = binary_cross_entropy(torch.sigmoid(predictions_binary), targets_binary)
    print(f"Binary Cross Entropy Loss: {loss_binary.item()}")
    
    # 假设有10个样本,每个样本预测为3分类问题的原始分数
    predictions_categorical = torch.randn(10, 3, requires_grad=True)
    targets_categorical = torch.randint(0, 3, (10,))
    
    # 计算多元交叉熵
    loss_categorical = categorical_cross_entropy(predictions_categorical, targets_categorical)
    print(f"Categorical Cross Entropy Loss: {loss_categorical.item()}")

    # 注意:在实际应用中,通常直接使用torch.nn.CrossEntropyLoss()来计算多元交叉熵,因为它内部已经包含了softmax操作
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

醉后才知酒浓

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值