CNN-RNN: A Unified Framework for Multi-label Image Classification
Paper PDF
文章目录
Introduction
随着大规模数据集的出现和深度神经网络的进步,图像分类的精度越来越高。但通常一副图像往往含有多种元素,包括对象,部分,场景,动作等等。对图像中丰富的语义信息及其依赖关系进行建模是图像理解的基础。传统的图像分类任中,每一张图片只有一个标签,而在多标签分类任务中一张图像往往具有若干个与之相关的标签。这种任务显然能够让计算机进一步理解图像内容。
在以往的多标签分类的解决办法中,一类常见的解决方法是讲多标签分类任务转化为单标签分类,如利用利用rank-loss或者cross-entropy loss训练的神经网络。这些方法往往没有考虑不同标签之间的依赖关系。当然现在也存在基于图模型的办法来对标签之间的关系进行建模,如利用马尔科夫随机场或共现概率等方法来推断标签对的联合概率。但这类方法仅仅考虑了标签之间的low-level关系,在面对标签数量增多的情况下,往往会导致计算量巨大。甚至在存在相似标签,如“cat" 和”kitten“的时候,就会导致大量的冗余信息。
本文提出了一种用于多标签图像分类的统一的CNN-RNN框架,该框架以端到端方式有效地学习了语义冗余和依赖关系。为了获取标签之间的high-level关系,本文利用LSTM显式地建模label的语义信息以及依赖关系,同时这种RNNs框架利用前一阶段的预测标签来动态调整输入图像的提取特征,这使得在预测不同标签时,网络能够注意不同的图像区域。在标签存在的冗余语义问题上,论文提出构建图像-标签的联合embedding。这种embedding将每一个标签或图像向低维欧几里得空间进行映射,以使语义相似标签彼此接近,并且每个图像的应与其关联的标签接近。
Innovation
- joint image/label embedding
- 利用RNN学习在joint image/label embedding 空间下不同标签间的co-occurrence dependency。
Method
Model
模型通过CNN来提取图像的特征,标签则映射到相同的低维空间。如下式:
w
k
=
U
l
⋅
e
k
(1)
w_k = U_l \cdot e_k \tag{1}
wk=Ul⋅ek(1)
其中,
e
k
e_k
ek表示k-th标签的one-hot向量,
U
l
U_l
Ul则表示标签映射矩阵(label embedding matrix),其中k-th行表示k-th标签的embedding 。(我觉得论文中表述有问题,应该是
e
k
⋅
U
l
e_k \cdot U_l
ek⋅Ul,得到一个(1 * n)的k-th标签的embedding)
将上一预测的标签embedding输入LSTM,通过其内部的非线性函数以及其隐藏递归状态对标签的共现依赖关系进行建模。
r
(
t
)
=
h
r
(
r
(
t
−
1
)
,
w
k
(
t
)
)
o
(
t
)
=
h
o
(
r
(
t
−
1
)
,
w
k
(
t
)
)
(2)
r(t) = h_r (r(t-1), w_k(t)) \\ o(t) = h_o(r(t-1), w_k(t)) \tag{2}
r(t)=hr(r(t−1),wk(t))o(t)=ho(r(t−1),wk(t))(2)
其中,
r
(
t
)
r(t)
r(t)和
o
(
t
)
o(t)
o(t)分别表示t时刻的隐藏状态以及输出。
w
k
(
t
)
w_k(t)
wk(t)是预测路径中第t个标签的embedding。
递归层的输出和图像特征做融合并被投影到与标签embedding相同的低维空间中。
x
t
=
h
(
U
o
x
o
(
t
)
+
U
I
x
I
)
(3)
x_t = h(U_o^xo(t) + U_I^xI) \tag{3}
xt=h(Uoxo(t)+UIxI)(3)
其中,
U
o
x
U_o^x
Uox和
U
I
x
U_I^x
UIx分别表示循环层输出和图像特征映射矩阵。
最后,通过将
x
t
x_t
xt和
U
l
U_l
Ul的转置相乘来计算标签分数,结果代表了
x
t
x_t
xt和每个标签embedding之间的距离。
s
(
t
)
=
U
l
T
x
t
(4)
s(t) = U_l^Tx_t \tag{4}
s(t)=UlTxt(4)
通过对标签得分softmax并归一化,可以计算出预测的标签概率。
Training
训练多标签CNN-RNN模型的一个重要问题是确定标签的顺序。本文中标签的顺序根据它们在训练数据中的出现频率来确定。频繁出现的标签比不频繁出现的标签出现得早,这与直觉相一致,即应该先预测较容易的物体,以帮助预测较难的物体。
Inference
通过beam search算法来寻找最大概率标签集合。
Experiment
Compare with SOTA
- 总体性能上超过SOTA,对大目标和相关性较强的目标检测性能好
- 很难检测相关性较小的小物体标签
Label embedding
embedding相近的标签具有较强的语义相关性。
Attention Visualization
虽然RNN框架没有学习明确的注意力模型,但在对不同目标进行分类时,它能够将注意力转向不同的图像区域。