一、CoSENT简介
目前已经有了很多比较有效的bert系列句向量方案,其中最易于实现以及数据构造比较友好并且效果也比较好的方法就是Sentence-BERT。当然Sentence-BERT的训练目标和推理任务的目标是不一致的,它的训练目标是优化一个分类任务,而推理却是计算文本的余弦相似度。那么有没有一种更好的方案来解决Sentence-BERT文本相似度训练和推理优化目标不一致的问题?苏神的博客——CoSENT(一):比Sentence-BERT更有效的句向量方案给出了答案,直接优化文本对的余弦相似度同label的差异!
以前在做文本相似度任务的时候,自己也尝试过直接优化cos相似度和label的差异,采用mse损失函数,最后的结果往往是很差的。苏神提出的CoSENT比较好的解决了这里不一致的问题,同时也提高了文本相似度任务的准确率。CoSENT本质上就是一种新的损失函数,它统计了所有正样本对余弦值与负样本对余弦值的差值同样本标签差值的损失,模型训练的时候就不会崩掉,也能提高文本相似度任务的准确率。具体的理论需要读者自己去详细阅读苏神的博客——CoSENT(一):比Sentence-BERT更有效的句向量方案。
以上公式就是苏神博客中的截图,只需要正本对的余弦值都大于负样本对的余弦值就能保证loss很小,文本相似度就很有区分度。这里也有一点对比学习的意味在里面。
CosEnt训练阶段
CosEnt推理阶段
这样就使得训练和推理阶段一致了
二、pytorch实现
这里的cosent_loss的实现参考了苏神的keras版本和shawroad同学的CoSENT_Pytorch版本的代码
def cosent_loss(output,y_true):
"""
:param output: 模型的输出对于文本对队里a和队里b的cos值,[a,b,c]--[a1,b1,c1]-->[cos<a,a1>,cos<b,b1>,cos<c,c1>]
:param y_true: 文本[a,b,c]--[a1,b1,c1]对应的有监督标签<a,a1>/<b,b1>/<c,c1>
:return: loss
参数 lamda 直接取20
"""
output = output*20
# 4. 取出负例-正例的差值
#利用了矩阵计算的广播机制
y_pred = output[:, None] - output[None, :] # 这里是算出所有位置 两两之间余弦的差值
# 矩阵中的第i行j列 表示的是第i个余弦值-第j个余弦值
y_true = y_true[:, None] < y_true[None, :] # 取出负例-正例的差值
y_true = y_true.float()
y_pred = y_pred - (1 - y_true) * 1e12 #这里之所以要这么减,是因为公式中所有的正样本对的余弦值减去负样本对的余弦值才计算损失,把不是这些部分通过exp(-inf)忽略掉
y_pred = y_pred.view(-1)
if torch.cuda.is_available():
y_pred = torch.cat((torch.tensor([0]).float().cuda(), y_pred), dim=0) # 这里加0是因为e^0 = 1相当于在log中加了1
else:
y_pred = torch.cat((torch.tensor([0]).float(), y_pred), dim=0) # 这里加0是因为e^0 = 1相当于在log中加了1
return torch.logsumexp(y_pred, dim=0)
注意output是文本对向量化后计算的余弦相似度;同时代码中的这几句需要结合公式看更好理解:
y_pred = output[:, None] - output[None, :] # 这里是算出所有位置 两两之间余弦的差值
# 矩阵中的第i行j列 表示的是第i个余弦值-第j个余弦值
y_true = y_true[:, None] < y_true[None, :] # 取出负例-正例的差值
y_true = y_true.float()
y_pred = y_pred - (1 - y_true) * 1e12 #这里之所以要这么减,是因为公式中所有的正样本对的余弦值减去负样本对的余弦值才计算损失,把不是这些部分通过exp(-inf)忽略掉
计算loss的时候公式中只是计算了所有的正样本对的余弦值-负样本对的余弦值,其余部分不参与计算。
三、中文实验
本次实验使用paws_x数据集,该数据集中的句子对是否相似都是模型比较难处理的文本。训练集4.3W,验证集2K,测试集2K。
数据如下:
预训练模型chinese-bert-wwm-ext、epochs=20、batch_size=128、lr=1e-5
模型代码:
CosEnt
import torch.nn as nn
from transformers import BertConfig, BertModel
from transformers import BertPreTrainedModel
import torch
class CosEnt(BertPreTrainedModel):
def __init__(self,config,max_len,tokenizer,device,task_type):
super(CosEnt,self).__init__(config)
self.max_len = max_len
self.task_type = task_type
self._target_device = device
self.tokenizer = tokenizer
self.bert = BertModel(config=config)
def forward(self,inputs):
input_a = inputs[0]
input_b = inputs[1]
output_a = self.bert(**input_a,return_dict=True, output_hidden_states=True)
output_b = self.bert(**input_b,return_dict=True, output_hidden_states=True)
#采用最后一层
embedding_a = output_a.hidden_states[-1]
embedding_b = output_b.hidden_states[-1]
embedding_a = self.pooling(embedding_a,input_a)
embedding_b = self.pooling(embedding_b, input_b)
d = torch.mul(embedding_a,embedding_b)
a_len = torch.norm(embedding_a,dim=1)
b_len = torch.norm(embedding_b,dim=1)
cos = torch.sum(d,dim=1)/(a_len*b_len)
output = cos
return output
def pooling(self,token_embeddings,input):
output_vectors = []
#attention_mask
attention_mask = input['attention_mask']
#[B,L]------>[B,L,1]------>[B,L,768],矩阵的值是0或者1
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
#这里做矩阵点积,就是对元素相乘(序列中padding字符,通过乘以0给去掉了)[B,L,768]
t = token_embeddings * input_mask_expanded
#[B,768]
sum_embeddings = torch.sum(t, 1)
# [B,768],最大值为seq_len
sum_mask = input_mask_expanded.sum(1)
#限定每个元素的最小值是1e-9,保证分母不为0
sum_mask = torch.clamp(sum_mask, min=1e-9)
#得到最后的具体embedding的每一个维度的值——元素相除
output_vectors.append(sum_embeddings / sum_mask)
#列拼接
output_vector = torch.cat(output_vectors, 1)
return output_vector
Sbert
import torch.nn as nn
from transformers import BertConfig, BertModel
from transformers import BertPreTrainedModel
import torch
class SentenceBert(BertPreTrainedModel):
def __init__(self,config,max_len,tokenizer,device,task_type):
super(SentenceBert,self).__init__(config)
self.max_len = max_len
self.task_type = task_type
self._target_device = device
self.tokenizer = tokenizer
self.bert = BertModel(config=config)
self.classifier = nn.Linear(3*config.hidden_size,config.num_labels)
def forward(self,inputs):
input_a = inputs[0]
input_b = inputs[1]
output_a = self.bert(**input_a,return_dict=True, output_hidden_states=True)
output_b = self.bert(**input_b,return_dict=True, output_hidden_states=True)
#采用最后一层
embedding_a = output_a.hidden_states[-1]
embedding_b = output_b.hidden_states[-1]
embedding_a = self.pooling(embedding_a,input_a)
embedding_b = self.pooling(embedding_b, input_b)
if self.task_type =="classification":
embedding_abs = torch.abs(embedding_a-embedding_b)
vectors_concat = []
vectors_concat.append(embedding_a)
vectors_concat.append(embedding_b)
vectors_concat.append(embedding_abs)
#列拼接3个768————>3*768
features = torch.cat(vectors_concat, 1)
output = self.classifier(features)
else:
d = torch.mul(embedding_a,embedding_b)
a_len = torch.norm(embedding_a,dim=1)
b_len = torch.norm(embedding_b,dim=1)
cos = torch.sum(d,dim=1)/(a_len*b_len)
output = cos
return output
def pooling(self,token_embeddings,input):
output_vectors = []
#attention_mask
attention_mask = input['attention_mask']
#[B,L]------>[B,L,1]------>[B,L,768],矩阵的值是0或者1
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
#这里做矩阵点积,就是对元素相乘(序列中padding字符,通过乘以0给去掉了)[B,L,768]
t = token_embeddings * input_mask_expanded
#[B,768]
sum_embeddings = torch.sum(t, 1)
# [B,768],最大值为seq_len
sum_mask = input_mask_expanded.sum(1)
#限定每个元素的最小值是1e-9,保证分母不为0
sum_mask = torch.clamp(sum_mask, min=1e-9)
#得到最后的具体embedding的每一个维度的值——元素相除
output_vectors.append(sum_embeddings / sum_mask)
#列拼接
output_vector = torch.cat(output_vectors, 1)
return output_vector
CosEnt训练采用cosent_loss,而不是mse_loss。Sbert训练直接采用分类任务的F.cross_entropy(output,labels);评测的时候计算文本相似度,选择最佳的阈值:
def predict(model,test_dataloader,device):
model.task_type = "cossimlarity"
cossimis = []
for step, batch in enumerate(test_dataloader):
batch = [t.to(device) for t in batch]
inputs_a = {'input_ids': batch[0], 'attention_mask': batch[1], 'token_type_ids': batch[2]}
inputs_b = {'input_ids': batch[3], 'attention_mask': batch[4], 'token_type_ids': batch[5]}
inputs = []
inputs.append(inputs_a)
inputs.append(inputs_b)
cossimi = model(inputs).detach().cpu().tolist()
cossimis.extend(cossimi)
test_df = pd.read_csv('data/paws_x/test_2k.tsv',sep='\t')
labels = test_df['label'].values.tolist()
bestacc = 0
bestthreshold = 0
thresholds = np.arange(0,1,0.01)
bestpreds = []
for threshold in thresholds:
preds = [ 1 if cos >= threshold else 0 for cos in cossimis]
correct = 0
for i in range(len(labels)):
if labels[i] == preds[i]:
correct += 1
acc = correct/len(labels)
if acc > bestacc:
bestacc = acc
bestthreshold = threshold
bestpreds = preds.copy()
print('Sbert acc: %.4f bestthreshold: %.4f'%(bestacc,bestthreshold))
test_df['preds'] = bestpreds
test_df['cossim'] = cossimis
test_df.to_csv('pawsx_Sbert.csv',sep='\t',index=False)
最后的结果如下:
CosEnt部分测试集数据cossim指的是相似度,preds为预测标签(相似度大于阈值就为1否则就是0)
id sentence1 sentence2 label preds cossim
10 2005 年末至 2009 年期间是例外,当时他效力于瑞典的卡斯塔德联队、塞尔维亚的查查克足球俱乐部和俄罗斯的格罗兹尼特里克足球俱乐部。 例外情况发生于 2005 年末至 2009 年期间,当时他效力于瑞典的卡斯塔德联队、塞尔维亚的查查克足球俱乐部和俄罗斯的格罗兹尼艾卡马特足球俱乐部。 1 1 0.9259706735610962
12 Tabaci 河是罗马尼亚 Leurda 河的支流。 Leurda 河是罗马尼亚境内 Tabaci 河的一条支流。 0 0 0.06759580224752426
20 1993 年,他为 A 级的坎恩郡美洲狮队和 AA 级的波特兰海狗队效力。 1993 年,他为 A 级球队波特兰海狗队和 AA 级球队凯恩县美洲狮队效力。 0 0 0.3313765823841095
26 Winarsky 是 IEEE、Phi Beta Kappa、ACM 和 Sigma Xi 的成员。 温那斯基是 ACM、IEEE、Phi Beta Kappa 和 Sigma Xi 的成员。 1 1 0.8139660954475403
27 1938 年,他成为英埃苏丹的政府人类学家,并领导对努巴的实地考察工作。 1938 年,他成为英埃苏丹政府的人类学家,并与努巴一起从事野外工作。 0 1 0.7984423041343689
28 比利·比利·贝特森出现在 2008 年末至 2009 年初出版的前四期《黑亚当》中。 黑亚当出现在 2008 年末至 2009 年初出版的前四期《比利·贝特森》中。 0 0 0.7027683258056641
32 利用太阳能满足此项要求的方法是在常规动力飞机上使用太阳能板。 利用太阳能满足此项要求的方法是在常规动力飞机上使用太阳能板。 1 1 0.9999999403953552
35 在调查进行期间,警察还质询了歌手梨美·托米和演员卡薇雅·马德哈万,两人均为西迪基及其妻子迪利普的好友。 作为持续进行的调查的一部分,警察还询问了歌手 Rimi Tomy 和演员 Kavya Madhavan,二人都是西迪基和他的妻子迪利普的好友。 1 0 0.43909767270088196
36 它们被稀疏的现场管弦乐覆盖,并由模糊不清和几乎故意平淡的声音构成。 它们被毫无音色的现场播放的音乐覆盖,并且由模糊不清、几乎故意稀疏的管弦乐声音构成。 0 0 0.7276858687400818
38 霍利在音乐上受到了埃尔顿·约翰的影响。 霍利 霍利在音乐上受到艾尔顿·约翰的影响。 1 1 0.9021589756011963
39 球队在 2 月 19 日当晚对下一场比赛的变化作出了回应。 该队在第二天(2 月 19 日)晚上的同一场比赛中应对变化。 0 0 0.7320891618728638
40 Nashua Silver Knights 队是当前夏季联赛的一部分,也是该市的大学体育队。 纳舒厄白银骑士团队加入了夏季大学联盟,是本市的现役球队。 0 0 0.7255858778953552
48 “断头台” 1949 年最后一次在西德使用,1966 年最后一次在东德使用。 Fall Beil 于 1949 年最后一次在西德使用并于 1966 年在东德使用。 1 1 0.9291356801986694
51 此运河为比利时,乃至欧洲,历史最为悠久的可通航运河之一。 这条运河是比利时(确切来说是欧洲)最古老的通航运河之一。 0 1 0.8468931913375854
52 他在 2009 年搬回了费城,现在住在纽约市。 他于 2009 年搬回费城,现居住在纽约市。 1 1 0.9845840930938721
53 1954 年 6 月 30 日,他因胃癌于俄克拉荷马州克莱摩尔病逝,而林恩·瑞格斯纪念馆则位于纽约市。 1954 年 6 月 30 日,他因胃癌于纽约市病逝。而林恩·瑞格斯纪念馆则位于俄克拉荷马州克莱摩尔。 0 0 0.01974690519273281
55 什蒂普西奇出生于德国科恩堡,在维也纳斯塔莫斯多夫度过了他的童年。 Stipsits 出生于科尔新堡,并在维也纳施塔莫斯多夫度过了他的童年。 1 0 0.7667046785354614
67 凯塔王朝从 12 世纪至 17 世纪初一直统治着前帝国及帝国时期的马里。 凯塔王朝在 12 世纪至 17 世纪统治了马里,在帝国成立前和帝国成立后都统治了。 1 0 0.5772057175636292
68 “天使之眼”是 1946 年的一首流行歌曲,由马特·丹尼斯作曲,厄尔·布伦特作词。 《Angel Eyes》是 1946 年的一首由 Earl Brent 作曲 Matt Dennis 作词的流行歌曲。 0 0 0.32964301109313965
Sbert 部分测试集数据cossim指的是相似度,preds为预测标签(相似度大于阈值就为1否则就是0)
id sentence1 sentence2 label preds cossim
10 2005 年末至 2009 年期间是例外,当时他效力于瑞典的卡斯塔德联队、塞尔维亚的查查克足球俱乐部和俄罗斯的格罗兹尼特里克足球俱乐部。 例外情况发生于 2005 年末至 2009 年期间,当时他效力于瑞典的卡斯塔德联队、塞尔维亚的查查克足球俱乐部和俄罗斯的格罗兹尼艾卡马特足球俱乐部。 1 1 0.9771590828895569
12 Tabaci 河是罗马尼亚 Leurda 河的支流。 Leurda 河是罗马尼亚境内 Tabaci 河的一条支流。 0 0 0.2890690565109253
20 1993 年,他为 A 级的坎恩郡美洲狮队和 AA 级的波特兰海狗队效力。 1993 年,他为 A 级球队波特兰海狗队和 AA 级球队凯恩县美洲狮队效力。 0 0 0.3313605785369873
26 Winarsky 是 IEEE、Phi Beta Kappa、ACM 和 Sigma Xi 的成员。 温那斯基是 ACM、IEEE、Phi Beta Kappa 和 Sigma Xi 的成员。 1 1 0.8842104077339172
27 1938 年,他成为英埃苏丹的政府人类学家,并领导对努巴的实地考察工作。 1938 年,他成为英埃苏丹政府的人类学家,并与努巴一起从事野外工作。 0 0 0.7835341095924377
28 比利·比利·贝特森出现在 2008 年末至 2009 年初出版的前四期《黑亚当》中。 黑亚当出现在 2008 年末至 2009 年初出版的前四期《比利·贝特森》中。 0 0 0.5175007581710815
32 利用太阳能满足此项要求的方法是在常规动力飞机上使用太阳能板。 利用太阳能满足此项要求的方法是在常规动力飞机上使用太阳能板。 1 1 1.0000001192092896
35 在调查进行期间,警察还质询了歌手梨美·托米和演员卡薇雅·马德哈万,两人均为西迪基及其妻子迪利普的好友。 作为持续进行的调查的一部分,警察还询问了歌手 Rimi Tomy 和演员 Kavya Madhavan,二人都是西迪基和他的妻子迪利普的好友。 1 0 0.6513717174530029
36 它们被稀疏的现场管弦乐覆盖,并由模糊不清和几乎故意平淡的声音构成。 它们被毫无音色的现场播放的音乐覆盖,并且由模糊不清、几乎故意稀疏的管弦乐声音构成。 0 0 0.7601726651191711
38 霍利在音乐上受到了埃尔顿·约翰的影响。 霍利 霍利在音乐上受到艾尔顿·约翰的影响。 1 1 0.8906985521316528
39 球队在 2 月 19 日当晚对下一场比赛的变化作出了回应。 该队在第二天(2 月 19 日)晚上的同一场比赛中应对变化。 0 0 0.6314491033554077
40 Nashua Silver Knights 队是当前夏季联赛的一部分,也是该市的大学体育队。 纳舒厄白银骑士团队加入了夏季大学联盟,是本市的现役球队。 0 0 0.5055076479911804
48 “断头台” 1949 年最后一次在西德使用,1966 年最后一次在东德使用。 Fall Beil 于 1949 年最后一次在西德使用并于 1966 年在东德使用。 1 1 0.8454433679580688
51 此运河为比利时,乃至欧洲,历史最为悠久的可通航运河之一。 这条运河是比利时(确切来说是欧洲)最古老的通航运河之一。 0 1 0.8682337403297424
52 他在 2009 年搬回了费城,现在住在纽约市。 他于 2009 年搬回费城,现居住在纽约市。 1 1 0.9858009219169617
53 1954 年 6 月 30 日,他因胃癌于俄克拉荷马州克莱摩尔病逝,而林恩·瑞格斯纪念馆则位于纽约市。 1954 年 6 月 30 日,他因胃癌于纽约市病逝。而林恩·瑞格斯纪念馆则位于俄克拉荷马州克莱摩尔。 0 0 0.004754672292619944
55 什蒂普西奇出生于德国科恩堡,在维也纳斯塔莫斯多夫度过了他的童年。 Stipsits 出生于科尔新堡,并在维也纳施塔莫斯多夫度过了他的童年。 1 0 0.7759240865707397
67 凯塔王朝从 12 世纪至 17 世纪初一直统治着前帝国及帝国时期的马里。 凯塔王朝在 12 世纪至 17 世纪统治了马里,在帝国成立前和帝国成立后都统治了。 1 0 0.5899931788444519
68 “天使之眼”是 1946 年的一首流行歌曲,由马特·丹尼斯作曲,厄尔·布伦特作词。 《Angel Eyes》是 1946 年的一首由 Earl Brent 作曲 Matt Dennis 作词的流行歌曲。 0 0 0.3959221839904785
统计结果
CosEnt acc: 0.7695 bestthreshold:0.7800
Sbert acc: 0.7255 bestthreshold: 0.7900
CosEnt在训练和推理任务一致后,相比较Sbert训练分类任务,推理文本相似不一致的方案,提升幅度巨大,4.4个百分点,文本向量余弦值区分度更加明显。
可以看出在paws_x数据集上,苏神的CoSENT生成的句向量在文本相似度任务上确实是比Sbert更加优秀!
参考文章