Learning to Discretely Compose Reasoning Module Networks for Video Captioning
1. 概要
- 发表:IJCAI 2020
- 代码:https://github.com/tgc1997/RMN
- idea:作者认为视频描述的生成是step-by-step的。对于一个句子的生成,首先需要定位和描述主语subject,接着推理动作,然后定位和描述宾语object。而这样一个过程,作者认为是需要复杂的时空推理。对于推理模块,作者设计了三个模块locate,relate,func,分别用于定位目标(2D),推理关系(3D)以及一些连词的生成(如a、the、and);对于选择模块,作者设计了Module Selector用于在生成下一个单词的时候选择上述模块中的一种。
2. 详细设计
2.1 Encoder
- 特征提取:分别使用2D-CNN, 3D-CNN, R-CNN提取了视频的appearance feature V a V_a Va, motion feature V m V_m Vm,object feature V o V_o Vo。注意这里的 V o V_o Vo是具有位置信息的(代码中有体现)
- 特征处理:对于 V a V_a Va和 V m V_m Vm,作者分别使用了Bi-LATMs处理以在特征中融入时间信息。
- 整个网络的指导信息
h
t
e
n
h_t^{en}
hten:LSTM的隐层输出。输入是全局视觉信息
v
ˉ
\bar v
vˉ,上一step生成的最后一个单词的embedding以及隐层状态
2.2 Reasoning Modules
所有的推理模块都是基于下面这个attention计算(Neural machine translation by jointly learning to align and translate.ICLR 2015)
这种方式定义的attention可以沿着指定纬度,为了更好的对空间和时间方向建模,作者分别定义了时间纬度和空间纬度上的attention:
A
o
S
(
⋅
)
AoS(\cdot)
AoS(⋅)和
A
o
T
(
⋅
)
AoT(\cdot)
AoT(⋅)
- Locate Module
主要是为了生成object words,如“man”、“basketball”等。需要模块在时间和空间上关注region信息,因此作者先将 V o V_o Vo送进 A o S ( ⋅ ) AoS(\cdot) AoS(⋅),然后再和 V a V_a Va一起送进 A o T ( ⋅ ) AoT(\cdot) AoT(⋅)
这里的 ⨁ \bigoplus ⨁表示concate操作 - Relate Module
主要是为了生成动词,例如“shoting”、“riding”等。在如下图所示的图片中,为了生成动词“shoting”,模型需要注意到不同场景中object状态的变化,因此在Relate Module中对任意的每一对空间attention处理后的 V o V_o Vo进行了配对,然后再执行时间attention
- Func Module
主要是为了生成一些连词使整个句子连贯,如“of”,“and”等。这里不需要视觉信息,只需要语言信息,因此对decoder LSTM的历史cell states执行AoT
可以发现这三个模块都是紧紧围绕着这一小节最开始提到的attention操作进行的,将 h t e n h_t^{en} hten作为attention的Q。
Module Selector
在生成模块中,每一个step生成的word只能是上述三个模块中的一种,因此需要设计一个选择模块进行选择。具体实现是对每一个模块进行打分,然后选择最高分。打分函数设计如下:
但是由于max函数不可微,所以作者使用了一种近似方法将one-hot vector
z
t
z_t
zt转换为连续的值
z
t
~
\tilde {z_t}
zt~
最终的视觉推理结果为:
这里的
⨂
\bigotimes
⨂表示inner product
Decoder
用了个LSTM进行解码,输入为视觉信息结果
v
t
v_t
vt,encoder的隐层
然后将视觉信息、隐层信息接一个MLP输出对应词典的概率分布得到生成的word
Training
- Caption Loss:cross-entropy loss
用于衡量生成句子的准确度
T T T表示句子长度 - POS Loss:KLD loss
用于衡量选择模块的准确性,具体是将句子的POS转换为one-hot编码,然后使用KLD(Kullback-Leibler Divergence) loss来衡量两个分布的相似度。实际在代码中实现也是用的cross-entropy loss
- 最终的loss