在修改网络的时候遇到了几种交叉熵分别是:
- torch.nn.functional.binary_cross_entropy
- torch.nn.functional.cross_entropy
接下来分别对其进行介绍
一、二分类交叉熵损失
torch.nn.functional.binary_cross_entropy其中的binary的意思就是二进制。适用于二分类问题,其中每个样本只能属于两个类别之一
下面举例说明:# 示例用法
loss = F.binary_cross_entropy(torch.sigmoid(output), target)
其中输入label的位置也就是上面torch.sigmoid(output)
部分可以看出,这个位置的输入需要经过归一化,将值映射到0-1之间的对样本的预测。而target位置呢需要输入float的数据类型,表示真实标签。其运算公式是:
二、多分类交叉熵损失
torch.nn.functional.cross_entropy适合多分类问题。下面举例
loss = F.cross_entropy(output, target)
output位置的输入是模型输入的原始分数即可,不需要经过softmax的激活函数。输入的target即为样本的标签,要求其元素的数据类型为整数类型,表示每个样本的真实类别。
计算公式是:
总的来说,主要区别在于输入的形式和适用问题的类型。binary_cross_entropy 用于二分类问题,需要经过 sigmoid 激活函数,而 cross_entropy 用于多分类问题,不需要 softmax 激活函数。选择合适的函数取决于你的问题和模型的输出形式。