最近在看多模态内容,记录一下文图模型中常用的损失函数。最先提出ITC loss的是论文ALBEF,下面是文章对该Loss的定义
假设有输入图片 I 经过image encoder之后变成{ v c l s , v 1 , … , v N v_{cls}, v_1, …, v_N vcls,v1,…,vN},输入文本 T 经过 text encoder 后变成{ w c l s , w 1 , … , w N w_{cls}, w_1,…, w_N wcls,w1,…,wN}
ITC Loss 的全称是 Image-Text Contrastive Loss ,为了在融合之前学习更好的unimodal表示,它学习 s = g v ( v c l s ) T g w ( w c l s ) s = g_v (v_{cls})^T g_w(w_{cls}) s=gv(vcls)Tgw(wcls),这里的 g v g_v gv和 g w g_w gw函数是给cls token embedding降维的线性层。另一方面,文图对会进入一个momentum unimodal encoders(这个结构的作用是通过结合过去更新中积累的知识,帮助稳定和提高学习表示的质量),变成 g ′ v ( v ′ c l s ) 和 g ′ w ( w ′ c l s ) g′_v (v′_{cls}) 和g′_w(w′_{cls}) g′v(v′cls)和g′w(w′cls)
定义:
s
(
I
,
T
)
=
g
v
(
v
c
l
s
)
T
g
′
w
(
w
′
c
l
s
)
s
(
T
,
I
)
=
g
w
(
w
c
l
s
)
T
g
′
v
(
v
′
c
l
s
)
s(I, T) = g_v (v_{cls})^T g′_w(w′_{cls}) \\ s(T, I) = g_w(w_{cls})^Tg′_v (v′_{cls})
s(I,T)=gv(vcls)Tg′w(w′cls)s(T,I)=gw(wcls)Tg′v(v′cls)
对于每个图像和文本,我们计算softmax归一化的图像到文本和文本到图像相似度为:
其中的
τ
\tau
τ是可学习的参数。令onehot相似度的真实值是
y
i
2
t
(
I
)
y^{i2t} (I)
yi2t(I) 和
y
t
2
i
(
T
)
y^{t2i}(T)
yt2i(T),真值中负样本对的概率为0,正样本对的概率为1。
ITC loss为
p
p
p和
y
y
y的交叉熵:
ALBEF代码中ITC loss对应的主要代码:
sim_i2t = image_feat @ text_feat_all / self.temp
sim_t2i = text_feat @ image_feat_all / self.temp
# image_feat和text_feat分别是图片和文本特征,text_feat_all和image_feat_all是从momentum encoder中取出来的文本、图像特征
# self.temp = nn.Parameter(torch.ones([]) * config['temp']) ,引入一个可学习的参数,可以对计算的结果进行缩放,从而调整模型
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
# F.log_softmax(sim_i2t, dim=1)对sim_i2t的每一行进行log_softmax计算
# sim_i2t和sim_i2t_targets的形状一样,sim_i2t_targets是真实值
# F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets是矩阵按元素相乘
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
# loss_t2i中的操作同上
loss_ita = (loss_i2t+loss_t2i)/2 #求平均得到ITC Loss
之后在更新同一篇文章中的Image-Text Matching (ITM) loss。