Sentence-Bert中pooling的理解

        之前使用Sentence-Bert思想来做文本匹配相似的时候,忽视了其中的pooling细节。在对embedding做处理的时候,只是简简单单的做了一个均值处理。代码如下:

embedding_a = self.bert(indextokens_a,input_mask_a)[0]
embedding_a = torch.mean(embedding_a,1)

以上的简单方式并没有什么错误,其中每个序列的embedding中含有作为padding字符(0)的向量表示,可能会影响最后下游任务的效果。直接处理就把这些padding字符的向量表示给去掉。上代码:

    def pooling(self,token_embeddings,input):
        output_vectors = []
        #attention_mask
        attention_mask = input['attention_mask']
        #[B,L]------>[B,L,1]------>[B,L,768],矩阵的值是0或者1
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        #这里做矩阵点积,就是对元素相乘(序列中padding字符,通过乘以0给去掉了)[B,L,768]
        t = token_embeddings * input_mask_expanded
        #[B,768]
        sum_embeddings = torch.sum(t, 1)

        # [B,768],最大值为seq_len
        sum_mask = input_mask_expanded.sum(1)
        #限定每个元素的最小值是1e-9,保证分母不为0
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        #得到最后的具体embedding的每一个维度的值——元素相除
        output_vectors.append(sum_embeddings / sum_mask)

        output_vector = torch.cat(output_vectors, 1)

        return  output_vector

这样的embedding方式表cls向量应该是好一点,从论文的结论以及从我自己的经验来看。

UKPLab/sentence-transformers官网代码

  • 2
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值