Abstract
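Multi-label image classification is the task of predicting a set of labels corresponding to the objects, attributes, and other entities present in an image. This paper proposes the Classification Transformer (C-Tran), a general framework for multi-label image classification that uses a Transformer to exploit the complex dependencies among visual features and labels. The core idea is to train a Transformer encoder to predict a set of target labels given two inputs: a set of masked labels and the visual features obtained from a convolutional network. The key ingredient is the label mask training objective, which uses a ternary encoding to represent the state of each label as positive, negative, or unknown during training.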
Introduction
Effective models for multi-label classification aim not only to extract good visual features that are predictive of image labels, but also to exploit the complex relations and dependencies between visual features and labels, and among the labels themselves.
Transformers have demonstrated a remarkable capability to exploit dependencies among sets of inputs using multi-headed self-attention layers. In our approach, a Transformer encoder is trained to reconstruct a set of target labels given an input set of masked label embeddings and a set of features obtained from a convolutional neural network.
C-Tran uses label masking during training to represent the state of each label as positive, negative, or unknown, analogous to how language models are trained with masked tokens.
At test time, C-Tran is able to predict a set of target labels using only input visual features by masking all the input labels as unknown.
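To make the masking step concrete, here is a minimal sketch of how ternary label states might be sampled for one training image. The function names, the integer codes for the three states, and the `min_unknown_frac` parameter are assumptions made for illustration; the text above does not say how many labels are masked per image.

```python
import random

# Hypothetical integer codes for the three label states.
UNKNOWN, NEGATIVE, POSITIVE = -1, 0, 1


def sample_label_states(targets, min_unknown_frac=0.25, rng=random):
    """Randomly hide some ground-truth labels for one training image.

    targets: list of 0/1 ground-truth labels (must be non-empty).
    Returns (states, loss_mask): states holds UNKNOWN/NEGATIVE/POSITIVE per
    label, and loss_mask marks the hidden labels the model must reconstruct.
    """
    num_labels = len(targets)
    # Assumption: hide at least a fixed fraction of labels, at most all of them.
    lowest = max(1, int(min_unknown_frac * num_labels))
    num_unknown = rng.randint(lowest, num_labels)
    hidden = set(rng.sample(range(num_labels), num_unknown))

    states = []
    for i, target in enumerate(targets):
        if i in hidden:
            states.append(UNKNOWN)
        else:
            states.append(POSITIVE if target == 1 else NEGATIVE)
    loss_mask = [i in hidden for i in range(num_labels)]
    return states, loss_mask


def all_unknown(num_labels):
    """Test-time input when no labels are given: every state is UNKNOWN."""
    return [UNKNOWN] * num_labels
```

Labels that stay visible act as context for the encoder, while the hidden ones are the prediction targets, which mirrors the masked-token analogy in the text.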
Beyond obtaining state-of-the-art results on standard multi-label classification, C-Tran is a more general model for reasoning under prior label observations. Because our approach explicitly models the label state (positive, negative, or unknown) during training, it can also be used at test time with partial or extra label annotations by setting the state of some of the labels as either positive or negative instead of masking them as unknown.
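As a rough illustration of inference with partial labels, the sketch below builds the state vector from whatever labels happen to be known for a test image. The helper name `states_from_known_labels`, the integer state codes, and the example label names are hypothetical and not taken from the paper.

```python
# Hypothetical integer codes for the three label states.
UNKNOWN, NEGATIVE, POSITIVE = -1, 0, 1


def states_from_known_labels(label_names, known):
    """Build per-label states for one test image.

    label_names: ordered list of all label names the model predicts.
    known: dict mapping a label name to True (present) or False (absent);
           labels missing from the dict are treated as unknown.
    """
    unexpected = set(known) - set(label_names)
    if unexpected:
        raise ValueError(f"labels not in vocabulary: {sorted(unexpected)}")

    states = []
    for name in label_names:
        if name not in known:
            states.append(UNKNOWN)
        elif known[name]:
            states.append(POSITIVE)
        else:
            states.append(NEGATIVE)
    return states


labels = ["person", "umbrella", "rain", "beach"]
# Metadata says it is raining and the photo was not taken at a beach.
print(states_from_known_labels(labels, {"rain": True, "beach": False}))
# -> [-1, -1, 1, 0]
```

Passing an empty dict reproduces the standard setting in which every label is masked as unknown.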
In general, we consider this setting as realistic since many images also have metadata in the form of extra labels such as location or weather information.
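Personal understanding:
Training phase: the input has three parts: the image, the labels, and a mask that is either annotated by hand or generated from some external prior. The image is first fed through a CNN to produce encoding A, while the labels are passed through an Embedding layer to produce encoding B; the two encodings are then concatenated to form C. The mask is likewise mapped through an Embedding layer, and the mask encoding is added element-wise to C to give feature D, which is fed into the Transformer for sequence encoding. After encoding, a fully connected layer maps the result to per-class outputs; a class is considered selected by the network whenever its activation exceeds 0.5.
Test phase: the labels are no longer used as input. To keep the input dimensions consistent, the mask is expanded to the required size with all of its values set to 0, and inference then proceeds as before.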
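The data flow described in these notes can be sketched in plain Python. This is only a toy under stated assumptions: the random vectors stand in for CNN features and learned embeddings, a single identity-projection self-attention step stands in for the Transformer encoder, and every name (`forward`, `predict`, `DIM`, and so on) is hypothetical. The state embedding is added to the label tokens only, with a zero vector for unknown; that equals adding a zero-padded mask encoding to the whole concatenated sequence.

```python
import math
import random

UNKNOWN, NEGATIVE, POSITIVE = -1, 0, 1  # hypothetical state codes
DIM = 4  # tiny embedding size, for illustration only


def rand_vec(rng, dim=DIM):
    return [rng.uniform(-1, 1) for _ in range(dim)]


def add(u, v):
    return [a + b for a, b in zip(u, v)]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def self_attention(tokens):
    """Single-head self-attention without learned projections."""
    scale = math.sqrt(len(tokens[0]))
    out = []
    for query in tokens:
        scores = [dot(query, key) / scale for key in tokens]
        top = max(scores)  # subtract the max for a stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        mixed = [0.0] * len(query)
        for w, value in zip(weights, tokens):
            mixed = [acc + (w / total) * x for acc, x in zip(mixed, value)]
        out.append(mixed)
    return out


def sigmoid(x):
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def forward(features, label_emb, state_emb, states, clf_w, clf_b):
    # B + mask encoding: add the state embedding to each label embedding.
    label_tokens = [add(label_emb[i], state_emb[s]) for i, s in enumerate(states)]
    tokens = features + label_tokens  # concatenate along the sequence axis
    encoded = self_attention(tokens)  # stand-in for the Transformer encoder
    label_out = encoded[len(features):]  # keep only the label positions
    # One linear unit per label, followed by a sigmoid.
    return [sigmoid(dot(clf_w[i], h) + clf_b[i]) for i, h in enumerate(label_out)]


def predict(probs, threshold=0.5):
    return [p > threshold for p in probs]


rng = random.Random(0)
num_labels = 3
features = [rand_vec(rng) for _ in range(5)]  # stand-in for CNN feature positions
label_emb = [rand_vec(rng) for _ in range(num_labels)]
state_emb = {UNKNOWN: [0.0] * DIM, NEGATIVE: rand_vec(rng), POSITIVE: rand_vec(rng)}
clf_w = [rand_vec(rng) for _ in range(num_labels)]
clf_b = [0.0] * num_labels

# Test phase: every label state is unknown, so only visual features inform the output.
probs = forward(features, label_emb, state_emb, [UNKNOWN] * num_labels, clf_w, clf_b)
print(predict(probs))
```

In a real implementation the stand-ins would be replaced by a trained CNN backbone, learned embedding tables, and a multi-layer Transformer encoder; the sketch only shows how the pieces are wired together.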
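Understanding how the mask is generated
First, positive means that object A and object B are more likely to co-occur; negative means that object A and object B are less likely to co-occur.
Second, there is the question of annotation. To judge whether two objects co-occur, one can draw on existing prior knowledge, such as bounding-box annotations, or the fact that the appearance of certain objects may be influenced by external factors such as weather, to annotate how likely multiple objects are to appear in the same image.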
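One simple way to obtain such co-occurrence priors, assuming per-image label sets are already available, is to count how often pairs of labels appear together. The sketch below illustrates the interpretation given in these notes rather than anything the paper prescribes; the function name and the toy annotations are made up.

```python
from collections import Counter
from itertools import combinations


def cooccurrence_prob(annotations):
    """Estimate P(b present | a present) from per-image label sets.

    annotations: list of sets of label names, one set per image.
    Returns a dict keyed by (a, b); pairs never seen together are omitted.
    """
    single = Counter()
    pair = Counter()
    for labels in annotations:
        single.update(labels)  # each label counted once per image
        for a, b in combinations(sorted(labels), 2):
            pair[(a, b)] += 1
            pair[(b, a)] += 1
    return {(a, b): count / single[a] for (a, b), count in pair.items()}


annotations = [
    {"rain", "umbrella"},
    {"rain", "umbrella", "person"},
    {"beach", "person"},
    {"rain", "person"},
]
probs = cooccurrence_prob(annotations)
print(probs[("umbrella", "rain")])  # -> 1.0, every umbrella image also has rain
```

Pairs with a high estimate could then be treated as likely to co-occur and pairs with a low or missing estimate as unlikely, with the cut-off left as a design choice.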