学习对比反事实样本,以实现稳健的视觉问答
Learning to Contrast the Counterfactual Samples for Robust Visual Question Answering
在阅读本文之前,一定要阅读论文:Counterfactual Samples Synthesizing for Robust Visual Question Answering(简称CSS)
方法
文章的方法主要包括三个部分:(1)一个基本的VQA模型。(2)一个事实和反事实样本合成(CSS)模块。(3)一个对比学习(CL)目标。
第一部分和第二部分都属于CSS已经实现的,主要作用在于
(1)并通过多分类的方法预测答案,并产生图中右上方基本VQAloss
(2)得到(I, I+, I-)和(Q, Q+, Q-),
第三部分
以(I, I+, I-)为例,将(I, I+, I-)和Q喂给VQA模型,分别产生原始样本的嵌入mm(V, Q)作为anchor(a),事实样本的嵌入mm(V+, Q)作为positive(p),反事实样本嵌入mm(V-, Q)作为negati(n)
利用余弦相似度作为评分函数,对正样本输出高值,对负样本输出低值,公式如下:
同样的方法得到anchor和negative之间的评分s(a, n), 这就相当于图中展示的,拉近原始图像与事实区域图像的关系,推远原始图像与反事实区域的距离。
对比损失定义为:(这就是图片下方得到的Contrastive loss)
最后,这种对比损失与基础分类损失的加权总和弥补了整体损失:
虽然文章说,这种方法能够使模型学习他们之间的关系,并从更有因果关系的方面预测正确答案。但是,个人感觉如果仅仅使以上方法,并不能从理论上提高模型的能力。