CrossEntropyLoss的理解

本文介绍了PyTorch中的CrossEntropyLoss,包括它的计算公式、使用方式以及输入输出的注意事项。通过手动计算和使用nn模块的LogSoftmax与NLLLoss来理解CrossEntropyLoss。特别强调了输入(input)应为(N,C),目标(target)为类别索引,且当batch大于1时,target形状应为(N)。同时讨论了target为概率的情况,指出在这种情况下需要与input形状相同,且不应使用多目标标签。" 104143061,6376822,理解Spring框架的AOP机制,"['Java', 'Spring', 'AOP', '事务管理', '日志']
摘要由CSDN通过智能技术生成

对于交叉熵,一直以来都是直接运用公式,最近稍微了解了一下

1 公式

nn.CrossEntropyLoss结合nn.LogSoftmax和nn.NLLLoss(负对数似然),也就是先计算Softmax,再计算NLLLoss

 手动计算

import torch
# 手动计算
y = torch.tensor([1,0,0])
z = torch.tensor([0.2,0.1,-0.1])
y_pred = np.exp(z) / np.exp(z).sum()
loss = (-y * np.log(y_pred)).sum()
print(loss) # tensor(0.9729)
  

使用nn.LogSoftmax和nn.NLLLoss

criterion = torch.nn.LogSoftmax()   
z_tensor = torch.tensor([0.2, 0.1, -0.1])
z_tensor = criterion(z_tensor)
print(z_tensor) # tensor([-0.9729, -1.0729, -1.2729])

criterion = torch.nn.NLLLoss()   
y_tensor = torch.LongTensor([0])
loss = criterion(z_tensor.reshape(1,3), y_tensor)
pr
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值