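- weight: optional. A 1-D tensor whose values are the per-class weights; its length must equal the number of classes. This parameter is very useful when the classes are imbalanced.
- size_average: defaults to True, which averages the loss over the mini-batch; otherwise the per-sample losses are summed. In recent PyTorch versions this argument is deprecated in favor of reduction, which accepts 'mean', 'sum', or 'none'.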
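To make the two parameters above concrete, here is a minimal sketch. The logits, targets, and class weights are made-up illustration values, and it assumes a PyTorch version that supports the reduction argument. It also shows one detail that is easy to miss: with weight set and mean reduction, the sum is divided by the total weight of the targets rather than by the batch size.

```python
import torch
import torch.nn as nn

# Hypothetical data for illustration: 4 samples, 3 classes
logits = torch.tensor([[2.0, 0.5, 0.1],
                       [0.2, 1.5, 0.3],
                       [1.0, 1.0, 1.0],
                       [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 1, 2, 2])

# Hypothetical per-class weights; length equals the number of classes
class_weights = torch.tensor([1.0, 2.0, 0.5])

# Unweighted loss under the three reduction modes
per_sample = nn.CrossEntropyLoss(reduction="none")(logits, targets)
loss_mean = nn.CrossEntropyLoss(reduction="mean")(logits, targets)
loss_sum = nn.CrossEntropyLoss(reduction="sum")(logits, targets)

# Weighted loss with the default "mean" reduction
weighted_mean = nn.CrossEntropyLoss(weight=class_weights)(logits, targets)

# Manual check: weighted mean = sum(w_i * loss_i) / sum(w_i),
# where w_i is the weight of sample i's target class
w = class_weights[targets]
manual_weighted = (w * per_sample).sum() / w.sum()

print(per_sample)
print(loss_mean, loss_sum)
print(weighted_mean, manual_weighted)
```

The last line should print two equal values, which confirms how the class weights enter the averaged loss.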
import torch
import torch.nn as nn
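a = torch.tensor([[1.0, 2.0, 3.0]])  # raw scores (logits): 1 sample, 3 classes
target = torch.tensor([2])  # target class index; integer tensors default to int64
logsoftmax = nn.LogSoftmax(dim=1)  # pass dim explicitly so normalization runs over the class dimension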
ce = nn.CrossEntropyLoss()
nll = nn.NLLLoss()
# Test CrossEntropyLoss
cel = ce(a, target)
print(cel)
# Output: tensor(0.4076)
# Test LogSoftmax + NLLLoss
lsm_a = logsoftmax(a)
nll_lsm_a = nll(lsm_a, target)
print(nll_lsm_a)
# Output: tensor(0.4076)
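So using nn.CrossEntropyLoss directly gives the same result as nn.LogSoftmax followed by nn.NLLLoss. CrossEntropyLoss applies log-softmax to its input and then the negative log-likelihood loss internally, so it expects raw, unnormalized scores rather than probabilities.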
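The equivalence can also be checked by hand. The sketch below recomputes the loss from its definition, the negative log-softmax value at the target index, using the same input as above. The variable names are illustrative, and torch.logsumexp is used here only as a numerically stable way to write the log of the softmax denominator.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 2.0, 3.0]])  # same input as the example above
target = torch.tensor([2])

# log_softmax(x)_j = x_j - log(sum_k exp(x_k))
log_probs = logits - torch.logsumexp(logits, dim=1, keepdim=True)

# Pick the log-probability of the target class for each sample,
# negate it, and average over the batch
rows = torch.arange(logits.size(0))
manual = -log_probs[rows, target].mean()

builtin = F.cross_entropy(logits, target)
print(manual, builtin)
```

Both values should agree with the 0.4076 reported above, since 3 - log(e^1 + e^2 + e^3) is approximately -0.4076.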