NLLLOSS(weight=)中weight解释

https://blog.csdn.net/weixin_37724529/article/details/107021786?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522162144406116780271573141%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&request_id=162144406116780271573141&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~first_rank_v2~rank_v29-15-107021786.first_rank_v2_pc_rank_v29&utm_term=+weight+nllloss&spm=1018.2226.3001.4187

 

  • weight:可选的,应该是一个tensor,里面的值对应类别的权重,如果样本不均衡的话,这个参数非常有用,长度是类别数目
  • szie_average:默认是True,会将mini-batch的loss求平均值;否则就是把loss累加起来
import torch
import torch.nn as nn
 
a = torch.Tensor([[1,2,3]])
target = torch.Tensor([2]).long()
logsoftmax = nn.LogSoftmax()
ce = nn.CrossEntropyLoss()
nll = nn.NLLLoss()
 
#测试CrossEntropyLoss
cel = ce(a,target)
print(cel)
#输出:tensor(0.4076)
 
#测试LogSoftmax+NLLLoss
lsm_a = logsoftmax(a)
nll_lsm_a = nll(lsm_a,target)
#输出tensor(0.4076)

看来直接用nn.CrossEntropy和nn.LogSoftmax+nn.NLLLoss是一样的结果

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值