多分类问题,选用交叉熵损失函数pytorch框架下的nn.CrossEntropyLoss()
注意:
1、该函数输入的预测结果应是原始结果,不能经过softmax、标准化normalized或者进行argmax后得到的[0,1,0]此类独热编码此类操作。
2、标签值的输入也不可以是独热编码,直接输入数据对应的分类label[1,6,2,4]即可。
3、torch
框架下nn.CrossEntropyLoss()
标签值是不需要one hot编码的。
注意:
1、该函数输入的预测结果应是原始结果,不能经过softmax、标准化normalized或者进行argmax后得到的[0,1,0]此类独热编码此类操作。
2、标签值的输入也不可以是独热编码,直接输入数据对应的分类label[1,6,2,4]即可。
3、torch
框架下nn.CrossEntropyLoss()
标签值是不需要one hot编码的。