Localized Contrastive Estimation LCE
论文链接: https://arxiv.org/abs/2101.08751
code: https://github.com/luyug/Reranker/
1.abstract
Pre-trained deep language models (LM) have advanced the state-of-the-art of text retrieval. Rerankers fine-tuned from deep LM estimates candidate relevance based on rich contextualized matching signals. Meanwhile, deep LMs can also be leveraged to improve search index, building retrievers with better recall. One would expect a straightforward combination of both in a pipeline to have additive performance gain. In this paper, we discover otherwise and that popular reranker cannot fully exploit the improved retrieval result. We, therefore, propose a Localized Contrastive Estimation (LCE) for training rerankers and demonstrate it significantly improves deep two-stage models.
摘要中主要提到了:目前预训练语言模型已经在文本检索任务中sota,基于预训练语言模型微调的reranker能够基于丰富的上下文特征预测相关性。同时,预训练语言模型还可以用来改进搜索索引,从而构建具有更好召回率的检索器。研究者期望结合这两种方法可以获得更好的检索性能。痛点是目前流行的reranker不能充分利用检索器给出的检索结果。因此,提出了一种用于训练reranker的局部对比预测(LCE)方法,并证明它能够显著提升两阶段模型的性能。
2.methodologies
2.1 Preliminaries
使用Bert预测Q-D pair的相关性得分: s = s c o r e ( q , d ) = v p t c l s ( B E R T ( c o n c a t ( q , d ) ) s=score(q,d)=v_p^t cls(BERT(concat(q,d)) s=score(q,d)=vptcls(BERT(concat(q,d))
2.2 Vanilla method
提到了两种对比方法,通过独立采样Q-D pairs,然后通过相关性得分和对应的标签
(
+
/
−
)
(+/-)
(+/−)计算二分类交叉熵(BCE)作为损失函数计算:
L
v
:
=
{
B
C
E
(
s
c
o
r
e
(
q
,
d
)
,
+
)
,
d is positive
B
C
E
(
s
c
o
r
e
(
q
,
d
)
,
−
)
,
d is negative
L_v:= \begin{cases} BCE(score(q,d),+), & \text {d is positive} \\ BCE(score(q,d),-), & \text{d is negative} \end{cases}
Lv:={BCE(score(q,d),+),BCE(score(q,d),−),d is positived is negative
传统方法是通过将重排问题作为一般的二分类问题处理。但本文提出,reranker是不同的,因为在重排问题中检索结果的top序列包含了更多易于混淆的信息(confounding signatures)(理解为Q-D pair之间的相似度差异不显著,因此易造成模型性能不佳)。作者提出,reranker应该具有以下能力:
- 能够处理检索出的top序列
- 避免在处理混淆特征时崩溃
基于此作者提出了Localized Contrastive Estimation(LCE) loss,这一损失函数将焦点放在了检索结果的top序列上,并能够有效避免混淆信息造成的崩溃。
2.3 Localized Negatives from Target Retriever
给定初始阶段检索器,训练query集合,使用检索器从doc语料库中,生成query集合中每一个query对应的按照相关性排序的doc序列。从这个序列的top m序列 R q m R_q^m Rqm中采样n个负样本作为负样本集。
2.4 Contrastive Loss
在从检索器检索结果中聚合得到负样本集后,对于query集合形成一组样本
G
q
Gq
Gq,其中的元素为每一个query对应的正样本
d
q
+
d_q^+
dq+和从
R
q
m
R_q^m
Rqm采样的n个负样本组成的负样本集。使用Bert计算Q-D相关性得分:
d
i
s
t
(
q
,
d
)
=
s
c
o
r
e
(
q
,
d
)
=
v
p
t
c
l
s
(
B
E
R
T
(
c
o
n
c
a
t
(
q
,
d
)
)
dist(q,d)=score(q,d)=v_p^t cls(BERT(concat(q,d))
dist(q,d)=score(q,d)=vptcls(BERT(concat(q,d))
定义一个query的contrastive loss:
L
q
:
=
−
l
o
g
e
x
p
(
d
i
s
t
(
q
,
d
+
)
)
∑
d
∈
G
q
e
x
p
(
d
i
s
t
(
q
,
d
)
)
L_q:=-log\frac{exp(dist(q,d^+))}{\sum_{d\in G_q}exp(dist(q,d))}
Lq:=−log∑d∈Gqexp(dist(q,d))exp(dist(q,d+))
2.5 LCE Batch Update
将上述损失放在一起,定义在一个有多个query的batch上: L L C E : = 1 ∣ Q ∣ ∑ q ∈ Q , G q ∽ R q m − l o g e x p ( d i s t ( q , d + ) ) ∑ d ∈ G q e x p ( d i s t ( q , d ) ) L_{LCE}:=\frac{1}{\lvert Q\rvert}\sum_{q\in Q,G_q\backsim R_q^m } -log\frac{exp(dist(q,d^+))}{\sum_{d\in G_q}exp(dist(q,d))} LLCE:=∣Q∣1q∈Q,Gq∽Rqm∑−log∑d∈Gqexp(dist(q,d))exp(dist(q,d+))
3. Experiment Methodologies
- Dataset:MSMARCO document ranking dataset
- Initial Stage Retriever:Indri,un-turned BM25,turned BM25(BM25*),HDCT