交叉熵
直接进入正题,对于一个二分类问题,我们可以定义如下的交叉熵代价函数,
其中,n代表总的样本数(batch)。对于真实类别为1的样本,其对应的损失为lna, 对于真实类别为0的样本,对应的损失为ln(1-a)。
而将其当做代价函数,本人清楚一下两点原因,其一,它是非负的,其二,如果对于所有的样本输入,其对应的输出越接近真实结果,那么这个loss将越趋向于0。
交叉熵更加常用于多分类任务,此时,这个损失函数应该改写成:
其中,(其值通常为1)中的i代表哪一类样本是正确的,代表正确的那一类预测出来的概率。
例如对于以下一段代码,可以明确的计算出CrossEntropyLoss的值: