很早就读过SentenceBERT这篇论文,后来工作中也用到了,最近重新看这份工作的时候发现有关SentenceBERT(SBERT)的东西都忘得差不多了,所以今天做一个回顾。
论文地址: Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks;
github地址:UKPLab/sentence-transformers
SBERT本质上还是基于BERT的变体,他得到的句向量就是BERT词向量的池化。
介绍
SBERT是基于BERT来获取句向量表示的模型,扩展了BERT的应用。常用于句子分类和句子相似度的计算等。
背景
BERT和RoBERTa在句子对任务上取得了sota的表现,但是每次计算,都需要将两个句子一起输入,这会导致巨大的计算开销。所以BERT其实并不适用于语义相似性搜索以及聚类等无监督任务。
SBERT是一个BERT的变体,用连体网络(siamese network)或三元组网络(triplet network)结构来生成语义上有意义的句子嵌入,可被用于余弦相似度的比较。在大规模的语义相似度比较,聚类和通过语义搜索的信息检索等任务上能够大大降低计算开销。
模型
SBERT通过在BERT/RoBERTa的输出上加一层pooling操作,来生成固定大小的句子嵌入。pooling操作主要有三种:
- 用CLS token的输出(CLS Pooling);
- 所有token向量的均值(mean pooling);
- 所有token向量中的最大值(max pooling);
默认为平均池化。基于连体网络(siamese network)或三元组网络(triplet network)来微调BERT/RoBERTa。主要用三个训练目标:
句子分类任务
给出两个句子
u
u
u,
v
v
v,让模型预测两个句子是否相似。分别将
u
u
u和
v
v
v输入到两个相同的BERT中,得到句向量。计算
u
u
u和
v
v
v的element-wise difference
∣
u
−
v
∣
|u−v|
∣u−v∣,然后将
u
u
u,
v
v
v,
∣
u
−
v
∣
|u−v|
∣u−v∣连接,输入给softmax,计算概率分布。
W
t
∈
R
3
n
∗
k
W_t ∈ R^{3n*k}
Wt∈R3n∗k
o = s o f t m a x ( W t ( u , v , ∣ u − v ∣ ) ) o = softmax(W_t(u, v, |u − v|)) o=softmax(Wt(u,v,∣u−v∣))
n
n
n为句子嵌入的维度,
k
k
k为标签类别数,用交叉熵损失优化。
句子回归任务
给出两个句子
u
u
u,
v
v
v,计算他们之间的相似度。和上面一样,将两个句子输入到两个相同的BERT中,得到句向量。然后直接计算二者的余弦相似度。用均方误差损失函数优化。
句子三元组任务
给定一个主句
a
a
a, 一个正样例句(positive sentence)
p
p
p和负样例句(negative sentence)
n
n
n,调整网络使得
a
a
a和
p
p
p之间的距离小于
a
a
a和
n
n
n之间的距离。最小化一下损失函数:
m
a
x
(
∣
∣
s
a
−
s
p
∣
∣
−
∣
∣
s
a
−
s
n
∣
∣
+
ϵ
,
0
)
max(|| s_a - s_p || - || s_a - s_n || + \epsilon , 0)
max(∣∣sa−sp∣∣−∣∣sa−sn∣∣+ϵ,0)
s
x
s_x
sx表示句子嵌入,
∣
∣
⋅
∣
∣
|| · ||
∣∣⋅∣∣表示距离(文中用欧氏距离),
ϵ
\epsilon
ϵ表示正样本
p
p
p和句子
a
a
a的句子最小要比负样本
n
n
n到
a
a
a的距离近
ϵ
\epsilon
ϵ。(是一种对比学习的方案)
这里主要做模型的简单介绍,完整细节和实验结果可以参考原文。
实战
SBERT有现成的库可以用,也提供了各种版本的模型:
from sentence_transformers import SentenceTransformer, util
model = SentenceTransformer('roberta-large-nli-stsb-mean-tokens')
# bert-base-nli-cls-token,bert-base-nli-mean-token等
我用的版本是roberta-large-nli-stsb-mean-tokens,也就是基于roberta-large版本的预训练模型,在NLI和STSB数据集上做微调的SBERT,且用均支池化策略来计算句子表示。
sentence1 = 'It was a great day'
sentence2 = 'Today was awesome'
# model.encode()
sentence1_embed = model.encode(sentence1)
sentence2_embed = model.encode(sentence2)
# 可以直接用util中给出的余弦相似度计算函数来计算
cos_sim = float(util.cos_sim(embedding, embed)[0][0])
print(cos_sim ) # cos_sim 为相似度值
这里放一下计算余弦相似度的方法代码:
def cos_sim(a: Tensor, b: Tensor):
"""
Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
:return: Matrix with res[i][j] = cos_sim(a[i], b[j])
"""
if not isinstance(a, torch.Tensor):
a = torch.tensor(a)
if not isinstance(b, torch.Tensor):
b = torch.tensor(b)
if len(a.shape) == 1:
a = a.unsqueeze(0)
if len(b.shape) == 1:
b = b.unsqueeze(0)
a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
return torch.mm(a_norm, b_norm.transpose(0, 1))
当然计算余弦相似度有很多种方法,比如pytorch官方文档中给出的torch.nn.functional.cosine_similarity(),spacy 库的spatial.distance.cosine(),他们的计算结果都是一样的,之后会再写他们的详细介绍。
参考:Sentence-BERT实战
如有错误欢迎指正!