多分类问题的交叉熵
在多分类问题中,损失函数(loss function)为交叉熵(cross entropy)损失函数。对于样本点(x,y)来说,y是真实的标签,在多分类问题中,其取值只可能为标签集合labels. 我们假设有K个标签值,且第i个样本预测为第k个标签值的概率为 p_{i,k}, 即p_{i,k} = Pr(t_{i,k} = 1), 一共有N个样本,则该数据集的损失函数为
一个例子
在Python的sklearn模块中,提供了一个函数log_loss()来计算多分类问题的交叉熵。再根据我们在博客Sklearn中二分类问题的交叉熵计算对log_loss()函数的源代码的分析,我们不难利用上面的计算公式用自己的方法来实现交叉熵的求值。
我们给出的例子如下:
y_true = ['1', '4', '5'] # 样本的真实标签
y_pred = [[0.1, 0.6, 0.3, 0, 0, 0, 0, 0, 0, 0],
[0, 0.3, 0.2, 0, 0.5, 0, 0, 0, 0, 0],
[0.6, 0.3, 0, 0, 0, 0.1, 0, 0