1、定义
多标签分类任务:一条数据可能有一个或者多个标签,类似于选择题中的不定项选择。
假设对图片进行分类,它的标签数目不是固定的,有的有一个标签,有的有两个标签,但标签的总数是固定的,比如 10 类。
2、标签编码方式
可以采用标签补齐的方法,将缺失的标签全部使用 0 标记,这样就不能使用 one-hot 编码了。例如,标签总数为 10,编号为 0-9,一张图片中包含 0,3,9 号标签,则这张图片的标签编码为:1 0 0 1 0 0 0 0 0 1
3、损失函数
使用神经网络处理多分类任务时,一般采用 softmax 作为输出层的激活函数,使用categorical_crossentropy(多类别交叉熵损失函数)作为损失函数,如下所示,其中
在多标签分类任务中,一般采用 sigmoid 作为输出层的激活函数,使用 binary_crossentropy(二分类交叉熵损失函数)作为损失函数,如下所示,其中
4、小结
多标签分类问题可以转换为多个二分类问题,计算一个样本各个标签的损失,然后取平均值,得到最后的损失。