《CoSENT(一):比Sentence-BERT更有效的句向量方案》pytorch实现和中文实验

一、CoSENT简介

     目前已经有了很多比较有效的bert系列句向量方案,其中最易于实现以及数据构造比较友好并且效果也比较好的方法就是Sentence-BERT。当然Sentence-BERT的训练目标和推理任务的目标是不一致的,它的训练目标是优化一个分类任务,而推理却是计算文本的余弦相似度。那么有没有一种更好的方案来解决Sentence-BERT文本相似度训练和推理优化目标不一致的问题?苏神的博客——CoSENT(一):比Sentence-BERT更有效的句向量方案给出了答案,直接优化文本对的余弦相似度同label的差异!

      以前在做文本相似度任务的时候,自己也尝试过直接优化cos相似度和label的差异,采用mse损失函数,最后的结果往往是很差的。苏神提出的CoSENT比较好的解决了这里不一致的问题,同时也提高了文本相似度任务的准确率。CoSENT本质上就是一种新的损失函数,它统计了所有正样本对余弦值与负样本对余弦值的差值同样本标签差值的损失,模型训练的时候就不会崩掉,也能提高文本相似度任务的准确率。具体的理论需要读者自己去详细阅读苏神的博客——CoSENT(一):比Sentence-BERT更有效的句向量方案

 以上公式就是苏神博客中的截图,只需要正本对的余弦值都大于负样本对的余弦值就能保证loss很小,文本相似度就很有区分度。这里也有一点对比学习的意味在里面。

CosEnt训练阶段

CosEnt推理阶段

这样就使得训练和推理阶段一致了

二、pytorch实现

      这里的cosent_loss的实现参考了苏神的keras版本和shawroad同学的CoSENT_Pytorch版本的代码

def cosent_loss(output,y_true):
    """
    :param output: 模型的输出对于文本对队里a和队里b的cos值,[a,b,c]--[a1,b1,c1]-->[cos<a,a1>,cos<b,b1>,cos<c,c1>]
    :param y_true: 文本[a,b,c]--[a1,b1,c1]对应的有监督标签<a,a1>/<b,b1>/<c,c1>
    :return: loss

    参数  lamda 直接取20
    """
    output = output*20
    # 4. 取出负例-正例的差值
    #利用了矩阵计算的广播机制
    y_pred = output[:, None] - output[None, :]  # 这里是算出所有位置 两两之间余弦的差值
    # 矩阵中的第i行j列  表示的是第i个余弦值-第j个余弦值
    y_true = y_true[:, None] < y_true[None, :]  # 取出负例-正例的差值
    y_true = y_true.float()
    y_pred = y_pred - (1 - y_true) * 1e12  #这里之所以要这么减,是因为公式中所有的正样本对的余弦值减去负样本对的余弦值才计算损失,把不是这些部分通过exp(-inf)忽略掉
    y_pred = y_pred.view(-1)
    if torch.cuda.is_available():
        y_pred = torch.cat((torch.tensor([0]).float().cuda(), y_pred), dim=0)  # 这里加0是因为e^0 = 1相当于在log中加了1
    else:
        y_pred = torch.cat((torch.tensor([0]).float(), y_pred), dim=0)  # 这里加0是因为e^0 = 1相当于在log中加了1

    return torch.logsumexp(y_pred, dim=0)

注意output是文本对向量化后计算的余弦相似度;同时代码中的这几句需要结合公式看更好理解:

y_pred = output[:, None] - output[None, :]  # 这里是算出所有位置 两两之间余弦的差值
    # 矩阵中的第i行j列  表示的是第i个余弦值-第j个余弦值
    y_true = y_true[:, None] < y_true[None, :]  # 取出负例-正例的差值
    y_true = y_true.float()
    y_pred = y_pred - (1 - y_true) * 1e12  #这里之所以要这么减,是因为公式中所有的正样本对的余弦值减去负样本对的余弦值才计算损失,把不是这些部分通过exp(-inf)忽略掉
    

计算loss的时候公式中只是计算了所有的正样本对的余弦值-负样本对的余弦值,其余部分不参与计算。

三、中文实验

本次实验使用paws_x数据集,该数据集中的句子对是否相似都是模型比较难处理的文本。训练集4.3W,验证集2K,测试集2K。

数据如下:

​​

预训练模型chinese-bert-wwm-ext、epochs=20、batch_size=128、lr=1e-5

模型代码:

CosEnt

import torch.nn as nn
from transformers import BertConfig, BertModel
from transformers import BertPreTrainedModel
import torch
class CosEnt(BertPreTrainedModel):
    def __init__(self,config,max_len,tokenizer,device,task_type):
        super(CosEnt,self).__init__(config)
        self.max_len = max_len
        self.task_type = task_type
        self._target_device = device
        self.tokenizer = tokenizer
        self.bert = BertModel(config=config)



    def forward(self,inputs):
        input_a = inputs[0]
        input_b = inputs[1]
        output_a = self.bert(**input_a,return_dict=True, output_hidden_states=True)
        output_b = self.bert(**input_b,return_dict=True, output_hidden_states=True)
        #采用最后一层
        embedding_a = output_a.hidden_states[-1]
        embedding_b = output_b.hidden_states[-1]
        embedding_a = self.pooling(embedding_a,input_a)
        embedding_b = self.pooling(embedding_b, input_b)

        d = torch.mul(embedding_a,embedding_b)
        a_len = torch.norm(embedding_a,dim=1)
        b_len = torch.norm(embedding_b,dim=1)
        cos = torch.sum(d,dim=1)/(a_len*b_len)
        output = cos
        return output



    def pooling(self,token_embeddings,input):
        output_vectors = []
        #attention_mask
        attention_mask = input['attention_mask']
        #[B,L]------>[B,L,1]------>[B,L,768],矩阵的值是0或者1
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        #这里做矩阵点积,就是对元素相乘(序列中padding字符,通过乘以0给去掉了)[B,L,768]
        t = token_embeddings * input_mask_expanded
        #[B,768]
        sum_embeddings = torch.sum(t, 1)

        # [B,768],最大值为seq_len
        sum_mask = input_mask_expanded.sum(1)
        #限定每个元素的最小值是1e-9,保证分母不为0
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        #得到最后的具体embedding的每一个维度的值——元素相除
        output_vectors.append(sum_embeddings / sum_mask)

        #列拼接
        output_vector = torch.cat(output_vectors, 1)

        return  output_vector

Sbert

import torch.nn as nn
from transformers import BertConfig, BertModel
from transformers import BertPreTrainedModel
import torch
class SentenceBert(BertPreTrainedModel):
    def __init__(self,config,max_len,tokenizer,device,task_type):
        super(SentenceBert,self).__init__(config)
        self.max_len = max_len
        self.task_type = task_type
        self._target_device = device
        self.tokenizer = tokenizer
        self.bert = BertModel(config=config)
        self.classifier = nn.Linear(3*config.hidden_size,config.num_labels)


    def forward(self,inputs):
        input_a = inputs[0]
        input_b = inputs[1]
        output_a = self.bert(**input_a,return_dict=True, output_hidden_states=True)
        output_b = self.bert(**input_b,return_dict=True, output_hidden_states=True)
        #采用最后一层
        embedding_a = output_a.hidden_states[-1]
        embedding_b = output_b.hidden_states[-1]
        embedding_a = self.pooling(embedding_a,input_a)
        embedding_b = self.pooling(embedding_b, input_b)

        if self.task_type =="classification":
            embedding_abs = torch.abs(embedding_a-embedding_b)
            vectors_concat = []
            vectors_concat.append(embedding_a)
            vectors_concat.append(embedding_b)
            vectors_concat.append(embedding_abs)
            #列拼接3个768————>3*768
            features = torch.cat(vectors_concat, 1)
            output = self.classifier(features)
        else:
            d = torch.mul(embedding_a,embedding_b)
            a_len = torch.norm(embedding_a,dim=1)
            b_len = torch.norm(embedding_b,dim=1)
            cos = torch.sum(d,dim=1)/(a_len*b_len)
            output = cos
        return output



    def pooling(self,token_embeddings,input):
        output_vectors = []
        #attention_mask
        attention_mask = input['attention_mask']
        #[B,L]------>[B,L,1]------>[B,L,768],矩阵的值是0或者1
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        #这里做矩阵点积,就是对元素相乘(序列中padding字符,通过乘以0给去掉了)[B,L,768]
        t = token_embeddings * input_mask_expanded
        #[B,768]
        sum_embeddings = torch.sum(t, 1)

        # [B,768],最大值为seq_len
        sum_mask = input_mask_expanded.sum(1)
        #限定每个元素的最小值是1e-9,保证分母不为0
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        #得到最后的具体embedding的每一个维度的值——元素相除
        output_vectors.append(sum_embeddings / sum_mask)

        #列拼接
        output_vector = torch.cat(output_vectors, 1)

        return  output_vector

CosEnt训练采用cosent_loss,而不是mse_loss。Sbert训练直接采用分类任务的F.cross_entropy(output,labels);评测的时候计算文本相似度,选择最佳的阈值:

def predict(model,test_dataloader,device):
    model.task_type = "cossimlarity"
    cossimis = []
    for step, batch in enumerate(test_dataloader):
        batch = [t.to(device) for t in batch]
        inputs_a = {'input_ids': batch[0], 'attention_mask': batch[1], 'token_type_ids': batch[2]}
        inputs_b = {'input_ids': batch[3], 'attention_mask': batch[4], 'token_type_ids': batch[5]}
        inputs = []
        inputs.append(inputs_a)
        inputs.append(inputs_b)
        cossimi = model(inputs).detach().cpu().tolist()
        cossimis.extend(cossimi)

    test_df = pd.read_csv('data/paws_x/test_2k.tsv',sep='\t')

    labels = test_df['label'].values.tolist()


    bestacc = 0
    bestthreshold = 0
    thresholds = np.arange(0,1,0.01)
    bestpreds = []

    for threshold in thresholds:
        preds = [ 1 if cos >= threshold else 0 for cos in cossimis]
        correct = 0
        for i in range(len(labels)):
            if labels[i] == preds[i]:
                correct += 1

        acc = correct/len(labels)
        if acc > bestacc:
            bestacc = acc
            bestthreshold = threshold
            bestpreds = preds.copy()


    print('Sbert acc: %.4f  bestthreshold: %.4f'%(bestacc,bestthreshold))
    test_df['preds'] = bestpreds
    test_df['cossim'] = cossimis

    test_df.to_csv('pawsx_Sbert.csv',sep='\t',index=False)

最后的结果如下:

CosEnt部分测试集数据cossim指的是相似度,preds为预测标签(相似度大于阈值就为1否则就是0)

id	sentence1	sentence2	label	preds	cossim
10	2005 年末至 2009 年期间是例外,当时他效力于瑞典的卡斯塔德联队、塞尔维亚的查查克足球俱乐部和俄罗斯的格罗兹尼特里克足球俱乐部。	例外情况发生于 2005 年末至 2009 年期间,当时他效力于瑞典的卡斯塔德联队、塞尔维亚的查查克足球俱乐部和俄罗斯的格罗兹尼艾卡马特足球俱乐部。	1	1	0.9259706735610962
12	Tabaci 河是罗马尼亚 Leurda 河的支流。	Leurda 河是罗马尼亚境内 Tabaci 河的一条支流。	0	0	0.06759580224752426
20	1993 年,他为 A 级的坎恩郡美洲狮队和 AA 级的波特兰海狗队效力。	1993 年,他为 A 级球队波特兰海狗队和 AA 级球队凯恩县美洲狮队效力。	0	0	0.3313765823841095
26	Winarsky 是 IEEE、Phi Beta Kappa、ACM 和 Sigma Xi 的成员。	温那斯基是 ACM、IEEE、Phi Beta Kappa 和 Sigma Xi 的成员。	1	1	0.8139660954475403
27	1938 年,他成为英埃苏丹的政府人类学家,并领导对努巴的实地考察工作。	1938 年,他成为英埃苏丹政府的人类学家,并与努巴一起从事野外工作。	0	1	0.7984423041343689
28	比利·比利·贝特森出现在 2008 年末至 2009 年初出版的前四期《黑亚当》中。	黑亚当出现在 2008 年末至 2009 年初出版的前四期《比利·贝特森》中。	0	0	0.7027683258056641
32	利用太阳能满足此项要求的方法是在常规动力飞机上使用太阳能板。	利用太阳能满足此项要求的方法是在常规动力飞机上使用太阳能板。	1	1	0.9999999403953552
35	在调查进行期间,警察还质询了歌手梨美·托米和演员卡薇雅·马德哈万,两人均为西迪基及其妻子迪利普的好友。	作为持续进行的调查的一部分,警察还询问了歌手 Rimi Tomy 和演员 Kavya Madhavan,二人都是西迪基和他的妻子迪利普的好友。	1	0	0.43909767270088196
36	它们被稀疏的现场管弦乐覆盖,并由模糊不清和几乎故意平淡的声音构成。	它们被毫无音色的现场播放的音乐覆盖,并且由模糊不清、几乎故意稀疏的管弦乐声音构成。	0	0	0.7276858687400818
38	霍利在音乐上受到了埃尔顿·约翰的影响。	霍利 霍利在音乐上受到艾尔顿·约翰的影响。	1	1	0.9021589756011963
39	球队在 2 月 19 日当晚对下一场比赛的变化作出了回应。	该队在第二天(2 月 19 日)晚上的同一场比赛中应对变化。	0	0	0.7320891618728638
40	Nashua Silver Knights 队是当前夏季联赛的一部分,也是该市的大学体育队。	纳舒厄白银骑士团队加入了夏季大学联盟,是本市的现役球队。	0	0	0.7255858778953552
48	“断头台” 1949 年最后一次在西德使用,1966 年最后一次在东德使用。	Fall Beil 于 1949 年最后一次在西德使用并于 1966 年在东德使用。	1	1	0.9291356801986694
51	此运河为比利时,乃至欧洲,历史最为悠久的可通航运河之一。	这条运河是比利时(确切来说是欧洲)最古老的通航运河之一。	0	1	0.8468931913375854
52	他在 2009 年搬回了费城,现在住在纽约市。	他于 2009 年搬回费城,现居住在纽约市。	1	1	0.9845840930938721
53	1954 年 6 月 30 日,他因胃癌于俄克拉荷马州克莱摩尔病逝,而林恩·瑞格斯纪念馆则位于纽约市。	1954 年 6 月 30 日,他因胃癌于纽约市病逝。而林恩·瑞格斯纪念馆则位于俄克拉荷马州克莱摩尔。	0	0	0.01974690519273281
55	什蒂普西奇出生于德国科恩堡,在维也纳斯塔莫斯多夫度过了他的童年。	Stipsits 出生于科尔新堡,并在维也纳施塔莫斯多夫度过了他的童年。	1	0	0.7667046785354614
67	凯塔王朝从 12 世纪至 17 世纪初一直统治着前帝国及帝国时期的马里。	凯塔王朝在 12 世纪至 17 世纪统治了马里,在帝国成立前和帝国成立后都统治了。	1	0	0.5772057175636292
68	“天使之眼”是 1946 年的一首流行歌曲,由马特·丹尼斯作曲,厄尔·布伦特作词。	《Angel Eyes》是 1946 年的一首由 Earl Brent 作曲 Matt Dennis 作词的流行歌曲。	0	0	0.32964301109313965

Sbert 部分测试集数据cossim指的是相似度,preds为预测标签(相似度大于阈值就为1否则就是0)

id	sentence1	sentence2	label	preds	cossim
10	2005 年末至 2009 年期间是例外,当时他效力于瑞典的卡斯塔德联队、塞尔维亚的查查克足球俱乐部和俄罗斯的格罗兹尼特里克足球俱乐部。	例外情况发生于 2005 年末至 2009 年期间,当时他效力于瑞典的卡斯塔德联队、塞尔维亚的查查克足球俱乐部和俄罗斯的格罗兹尼艾卡马特足球俱乐部。	1	1	0.9771590828895569
12	Tabaci 河是罗马尼亚 Leurda 河的支流。	Leurda 河是罗马尼亚境内 Tabaci 河的一条支流。	0	0	0.2890690565109253
20	1993 年,他为 A 级的坎恩郡美洲狮队和 AA 级的波特兰海狗队效力。	1993 年,他为 A 级球队波特兰海狗队和 AA 级球队凯恩县美洲狮队效力。	0	0	0.3313605785369873
26	Winarsky 是 IEEE、Phi Beta Kappa、ACM 和 Sigma Xi 的成员。	温那斯基是 ACM、IEEE、Phi Beta Kappa 和 Sigma Xi 的成员。	1	1	0.8842104077339172
27	1938 年,他成为英埃苏丹的政府人类学家,并领导对努巴的实地考察工作。	1938 年,他成为英埃苏丹政府的人类学家,并与努巴一起从事野外工作。	0	0	0.7835341095924377
28	比利·比利·贝特森出现在 2008 年末至 2009 年初出版的前四期《黑亚当》中。	黑亚当出现在 2008 年末至 2009 年初出版的前四期《比利·贝特森》中。	0	0	0.5175007581710815
32	利用太阳能满足此项要求的方法是在常规动力飞机上使用太阳能板。	利用太阳能满足此项要求的方法是在常规动力飞机上使用太阳能板。	1	1	1.0000001192092896
35	在调查进行期间,警察还质询了歌手梨美·托米和演员卡薇雅·马德哈万,两人均为西迪基及其妻子迪利普的好友。	作为持续进行的调查的一部分,警察还询问了歌手 Rimi Tomy 和演员 Kavya Madhavan,二人都是西迪基和他的妻子迪利普的好友。	1	0	0.6513717174530029
36	它们被稀疏的现场管弦乐覆盖,并由模糊不清和几乎故意平淡的声音构成。	它们被毫无音色的现场播放的音乐覆盖,并且由模糊不清、几乎故意稀疏的管弦乐声音构成。	0	0	0.7601726651191711
38	霍利在音乐上受到了埃尔顿·约翰的影响。	霍利 霍利在音乐上受到艾尔顿·约翰的影响。	1	1	0.8906985521316528
39	球队在 2 月 19 日当晚对下一场比赛的变化作出了回应。	该队在第二天(2 月 19 日)晚上的同一场比赛中应对变化。	0	0	0.6314491033554077
40	Nashua Silver Knights 队是当前夏季联赛的一部分,也是该市的大学体育队。	纳舒厄白银骑士团队加入了夏季大学联盟,是本市的现役球队。	0	0	0.5055076479911804
48	“断头台” 1949 年最后一次在西德使用,1966 年最后一次在东德使用。	Fall Beil 于 1949 年最后一次在西德使用并于 1966 年在东德使用。	1	1	0.8454433679580688
51	此运河为比利时,乃至欧洲,历史最为悠久的可通航运河之一。	这条运河是比利时(确切来说是欧洲)最古老的通航运河之一。	0	1	0.8682337403297424
52	他在 2009 年搬回了费城,现在住在纽约市。	他于 2009 年搬回费城,现居住在纽约市。	1	1	0.9858009219169617
53	1954 年 6 月 30 日,他因胃癌于俄克拉荷马州克莱摩尔病逝,而林恩·瑞格斯纪念馆则位于纽约市。	1954 年 6 月 30 日,他因胃癌于纽约市病逝。而林恩·瑞格斯纪念馆则位于俄克拉荷马州克莱摩尔。	0	0	0.004754672292619944
55	什蒂普西奇出生于德国科恩堡,在维也纳斯塔莫斯多夫度过了他的童年。	Stipsits 出生于科尔新堡,并在维也纳施塔莫斯多夫度过了他的童年。	1	0	0.7759240865707397
67	凯塔王朝从 12 世纪至 17 世纪初一直统治着前帝国及帝国时期的马里。	凯塔王朝在 12 世纪至 17 世纪统治了马里,在帝国成立前和帝国成立后都统治了。	1	0	0.5899931788444519
68	“天使之眼”是 1946 年的一首流行歌曲,由马特·丹尼斯作曲,厄尔·布伦特作词。	《Angel Eyes》是 1946 年的一首由 Earl Brent 作曲 Matt Dennis 作词的流行歌曲。	0	0	0.3959221839904785

统计结果

CosEnt acc: 0.7695  bestthreshold:0.7800

Sbert acc: 0.7255  bestthreshold: 0.7900

CosEnt在训练和推理任务一致后,相比较Sbert训练分类任务,推理文本相似不一致的方案,提升幅度巨大,4.4个百分点,文本向量余弦值区分度更加明显。

可以看出在paws_x数据集上,苏神的CoSENT生成的句向量在文本相似度任务上确实是比Sbert更加优秀!

参考文章

CoSENT(一):比Sentence-BERT更有效的句向量方案

苏剑林kears版本cosent_loss

CoSENT_Pytorch

  • 5
    点赞
  • 39
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
方案是为解决特定问题或达成特定目标而制定的一系列计划或步骤。它的作用是提供一种系统性的方法,以有效地应对挑战、优化流程或实现目标。以下是方案的主要作用: 问题解决: 方案的核心目标是解决问题。通过系统性的规划和执行,方案能够分析问题的根本原因,提供可行的解决方案,并引导实施过程,确保问题得到合理解决。 目标达成: 方案通常与明确的目标相关联,它提供了一种达成这些目标的计划。无论是企业战略、项目管理还是个人发展,方案的制定都有助于明确目标并提供达成目标的路径。 资源优化: 方案在设计时考虑了可用资源,以最大化其效用。通过明智的资源分配,方案可以在有限的资源条件下实现最大的效益,提高效率并减少浪费。 风险管理: 方案通常会对潜在的风险进行评估,并制定相应的风险管理策略。这有助于减轻潜在问题的影响,提高方案的可行性和可持续性。 决策支持: 方案提供了决策者所需的信息和数据,以便做出明智的决策。这种数据驱动的方法有助于减少不确定性,提高决策的准确性。 团队协作: 复杂的问题通常需要多个人的协同努力。方案提供了一个共同的框架,帮助团队成员理解各自的职责和任务,促进协作并确保整个团队朝着共同的目标努力。 监控与评估: 方案通常包括监控和评估的机制,以确保实施的有效性。通过定期的评估,可以及时调整方案,以适应变化的环境或新的挑战。 总体而言,方案的作用在于提供一种有序、有计划的方法,以解决问题、实现目标,并在实施过程中最大化资源利用和风险管理。 方案是为解决特定问题或达成特定目标而制定的一系列计划或步骤。它的作用是提供一种系统性的方法,以有效地应对挑战、优化流程或实现目标。以下是方案的主要作用: 问题解决: 方案的核心目标是解决问题。通过系统性的规划和执行,方案能够分析问题的根本原因,提供可行的解决方案,并引导实施过程,确保问题得到合理解决。 目标达成: 方案通常与明确的目标相关联,它提供了一种达成这些目标的计划。无论是企业战略、项目管理还是个人发展,方案的制定都有助于明确目标并提供达成目标的路径。 资源优化: 方案在设计时考虑了可用资源,以最大化其效用。通过明智的资源分配,方案可以在有限的资源条件下实现最大的效益,提高效率并减少浪费。 风险管理: 方案通常会对潜在的风险进行评估,并制定相应的风险管理策略。这有助于减轻潜在问题的影响,提高方案的可行性和可持续性。 决策支持: 方案提供了决策者所需的信息和数据,以便做出明智的决策。这种数据驱动的方法有助于减少不确定性,提高决策的准确性。 团队协作: 复杂的问题通常需要多个人的协同努力。方案提供了一个共同的框架,帮助团队成员理解各自的职责和任务,促进协作并确保整个团队朝着共同的目标努力。 监控与评估: 方案通常包括监控和评估的机制,以确保实施的有效性。通过定期的评估,可以及时调整方案,以适应变化的环境或新的挑战。 总体而言,方案的作用在于提供一种有序、有计划的方法,以解决问题、实现目标,并在实施过程中最大化资源利用和风险管理。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值