什么是多标签分类
图片中是否包含房子?你的回答就是有或者没有,这就是一个典型的二分类问题(一个问题两个选项,是或不是)。
同样还是这幅照片,问题变成了:这幅照片是谁拍摄的?备选答案你,你的父亲,你的母亲,这就变成了一个多分类问题(一个问题多个选项)。
若此时问题如下:
你会发现图中所示的答案有多个yes,而不同于之前的多分类只有一个yes。这就是多标签分类。
多标签的问题的损失函数是什么
这里需要先了解一下softmax 与 sigmoid函数
这两个函数最重要的区别,我们观察一下:
区别还是很明显的。
综上,我们可以得出以下结论:
pytorch中的实现
torch.nn.CrossEntropyLoss()
交叉熵快速理解参考:博客2
CrossEntropyLoss(pre, label, ignore_index = -1)
该函数会默认对预测值pre做softmax后,再计算交叉熵。
BCELoss
在图片多标签分类时,如果3张图片分3类,会输出一个3*3的矩阵。
需要先用Sigmoid给这些值都映射到0~1之间:
假设Target是:
计算后的loss为:
BCEWithLogitsLoss
BCEWithLogitsLoss就是把Sigmoid+BCELoss合成一步。我们直接用刚刚的input验证一下是不是0.7193:
补充阅读:
关于交叉熵损失函数详解:博客1