背景
1、在做大模型训练的时候,我们需要进行数据集的处理,而很多情况下我们收集到的数据集会存在重复数据,针对去重,有两种,一种是完全重复,也就是数据集里面有一个A,还有存在了着另外一个A;还有一种重复是语义上面非常相似,比如“天龙八部”当前热度1800W,“天龙八部”当前热度1810W,其实这两个数据对训练/测试来说并没有什么特别的地方,所以作为这样的数据的处理,需要进行语义级别的去重。
解决方案
利用SentenceTransformer框架来搭建一个语义搜索服务,也就是将待去重的文本内容,利用sentence embedding,然后根据向量计算余弦值,得出句子的相似度,然后利用util的semantic_search方法进行搜索
代码
import torch
from sentence_transformers import SentenceTransformer, util
def semanticTextualDeduplication_loop(lines, threshold):
embedder = SentenceTransformer("/data/dh/model/LaBSE")
corpus = []
corpus.append(lines[0])
corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)
lines.pop(0)
for line in lines:
queries = [line]
print('line: {0} and corpus.size: {1}'.format(line, len(corpus)))
query_embedding = embedder.encode(queries, convert_to_tensor=True)
if not semantic_search_exist(query_embedding, corpus_embeddings, threshold, line):
corpus.append(line)
print('before corpus_embeddings size: {0}'.format(len(corpus_embeddings)))
# 沿着指定维度拼接张量
corpus_embeddings = torch.cat((query_embedding, corpus_embeddings), dim=0)
print('end corpus_embeddings size: {0}'.format(len(corpus_embeddings)))
return corpus
def semantic_search_exist(query_embedding, corpus_embeddings, threshold, query):
hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=1)
hits = hits[0]
hits = hits[0]
score = hits['score']
print('input: {0} and output: {1}'.format(query, hits['score']))
if score > threshold:
return True
return False
corpus = [
"A man is riding a horse.",
"A woman is playing violin.",
"Two men pushed carts through the woods.",
"A man is riding a white horse on an enclosed ground.",
"A monkey is playing drums.",
"A cheetah is running behind its prey.",
]
result = semanticTextualDeduplication_loop(corpus, 0.9)
print("result:", result)