nn.CrossEntropyLoss()
是nn.logSoftmax()
和nn.NLLLoss()
的整合,可以直接使用它来替换网络中的这两个操作,这个函数可以用于多分类问题。具体的计算过程可以参考官网的公式或者一下这个链接。
https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss
https://blog.csdn.net/geter_CS/article/details/84857220
https://zhuanlan.zhihu.com/p/98785902
1. 函数的参数
- weight(Tensor, optional):如果输入这个参数的话必须是一个1维的tensor,长度为类别数C,每个值对应每一类的权重。
- reduction (string, optional) :指定最终的输出类型,默认为’mean’。
none | 无操作 |
---|---|
mean | 输出的结果求均值 |
sum | 输出的结果求和 |
- 其他参数暂时没有用到,回头用到再补充。
2. 函数的使用方法
- 传入input以及target即可,但是需要注意两者的格式,具体见第3点。
- 调用的时候需要注意,不能直接调用,查看第三点例子。
3. 使用注意事项
- 如果不同类别对应的权重不同,传入的权重参数应该是一个1维的tensor。
- 输入的每一类的置信度得分(
input
)应该是原始的,未经过softmax或者normalized。原因是这个函数会首先对输入的原始得分进行softmax,所以必须保证输入的是每一类的原始得分。不能写成[0.2, 0.36, 0.44]
这种softmax之后的或者[0, 1, 0]
这种one-hot编码。 - 输入的
target
也不能是one-hot标签,直接输入每个例子对应的类别编号就行了(0 < target_value < C-1
),比如产生的结果数为N*C
(N为个数,C为类别数),那么输入的target
必须输入一个长度为N