二分类任务交叉熵损失函数定义
多分类任务的交叉熵损失函数定义为:
其中
是向量,
表示样本预测为第c类的概率。
如果是二分类任务的话,因为只有正例和负例,且两者的概率和是1,所以不需要预测一个向量,只需要预测一个概率就好了,损失函数定义简化如下:
其中
是模型预测样本是正例的概率,
是样本标签,如果样本属于正例,取值为1,否则取值为0。
PyTorch中二分类交叉熵损失函数的实现
PyTorch提供了两个类来计算二分类交叉熵(Binary Cross Entropy),分别是BCELoss() 和BCEWithLogitsLoss()torch.nn.BCELoss()
类定义如下
torch.nn.BCELoss(
weight=None,
size_average=None,
reduction="mean",
)
用N表示样本数量,
表示预测第n个样本为正例的概率,