之前使用Sentence-Bert思想来做文本匹配相似的时候,忽视了其中的pooling细节。在对embedding做处理的时候,只是简简单单的做了一个均值处理。代码如下:
embedding_a = self.bert(indextokens_a,input_mask_a)[0]
embedding_a = torch.mean(embedding_a,1)
以上的简单方式并没有什么错误,其中每个序列的embedding中含有作为padding字符(0)的向量表示,可能会影响最后下游任务的效果。直接处理就把这些padding字符的向量表示给去掉。上代码:
def pooling(self,token_embeddings,input):
output_vectors = []
#attention_mask
attention_mask = input['attention_mask']
#[B,L]------>[B,L,1]------>[B,L,768],矩阵的值是0或者1
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
#这里做矩阵点积,就是对元素相乘(序列中padding字符,通过乘以0给去掉了)[B,L,768]
t = token_embeddings * input_mask_expanded
#[B,768]
sum_embeddings = torch.sum(t, 1)
# [B,768],最大值为seq_len
sum_mask = input_mask_expanded.sum(1)
#限定每个元素的最小值是1e-9,保证分母不为0
sum_mask = torch.clamp(sum_mask, min=1e-9)
#得到最后的具体embedding的每一个维度的值——元素相除
output_vectors.append(sum_embeddings / sum_mask)
output_vector = torch.cat(output_vectors, 1)
return output_vector
这样的embedding方式表cls向量应该是好一点,从论文的结论以及从我自己的经验来看。