Sentence Transformer 调试
需求:测试 HK 与 GB 两份文档数据之间的相似度——对 HK 中的每条数据,找出 GB 中相似度最高的前五条数据,最终生成一份“一条 HK 数据对应五条 GB 数据”的相似度结果文档。
import pandas as pd
# 查看当前编码方式
# import sys
# print(sys.stdout.encoding)
# 修改文档编码utf-8问题,python3默认为utf-8(调试)
# import importlib,sys
# importlib.reload(sys)
# import sys
# import codecs
# sys.stdout = codecs.getwriter("utf-8")(sys.stdout.detach())
# Pretrained multilingual sentence-embedding model; the same instance is used
# to encode both the HK and the GB sentence lists.
from sentence_transformers import SentenceTransformer, util
model = SentenceTransformer(
    model_name_or_path='paraphrase-multilingual-mpnet-base-v2',
)
def cos_sim_val(embedding_1, embedding_2):
    """Return the pairwise cosine-similarity matrix of two embedding sets.

    Thin wrapper around ``sentence_transformers.util.cos_sim``: entry
    ``[i][j]`` of the result is the cosine similarity between
    ``embedding_1[i]`` and ``embedding_2[j]``.

    :param embedding_1: embeddings of the first sentence set
        (e.g. the output of ``model.encode``).
    :param embedding_2: embeddings of the second sentence set.
    :return: the similarity matrix produced by ``util.cos_sim``.
    """
    return util.cos_sim(embedding_1, embedding_2)
def _read_tsv_columns(path, sentence_col, key_col, dept_col):
    """Read a UTF-8 tab-separated file into three parallel column lists.

    The trailing line break is removed before splitting, so the last column
    does not carry a stray ``'\\n'``. Blank lines are skipped. A line with
    too few columns raises ``IndexError``.

    :param path: path of the TSV file.
    :param sentence_col: index of the sentence (problem text) column.
    :param key_col: index of the problem-key column.
    :param dept_col: index of the department-name column.
    :return: ``(sentences, problem_keys, dept_names)`` lists, in file order.
    """
    sentences, keys, depts = [], [], []
    with open(path, 'r', encoding='utf-8') as f:
        for raw in f:
            line = raw.rstrip('\r\n')
            if not line.strip():
                continue  # e.g. a trailing empty line at end of file
            fields = line.split('\t')
            sentences.append(fields[sentence_col])
            keys.append(fields[key_col])
            depts.append(fields[dept_col])
    return sentences, keys, depts


def _top_k_indices(scores, k):
    """Return the indices of the ``k`` highest values in ``scores``, best first.

    Python's sort is stable even with ``reverse=True``, so equal scores keep
    ascending index order. Fewer than ``k`` items simply yields all indices.
    """
    return sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:k]


if __name__ == '__main__':
    import os

    TOP_K = 5  # number of most similar GB rows reported per HK row
    OUT_DIR = './hk-gb_data_pmmbv_result'

    # HK file columns: sentence, problem key, department name.
    sentence_1, PROBLEM_KEY_1, DEPT_NAME_1 = _read_tsv_columns(
        'HK3_7.txt', sentence_col=0, key_col=1, dept_col=2)
    # GB file columns: problem key, sentence, department name.
    sentence_2, PROBLEM_KEY_2, DEPT_NAME_2 = _read_tsv_columns(
        'GB3_1.txt', sentence_col=1, key_col=0, dept_col=2)
    print("sentence_1:", sentence_1)
    print("sentence_2:", sentence_2)

    # Encode all sentences
    embedding_1 = model.encode(sentence_1)
    embedding_2 = model.encode(sentence_2)
    # Cosine similarity of every HK sentence against every GB sentence
    cos_sim = cos_sim_val(embedding_1, embedding_2)

    # Output table: one row per (HK row, one of its top-K GB matches).
    realist = [['HK_PROBLEM_KEY', 'HK_PROBLEM_NAME', "HK_DEPT_NAME",
                'GB_PROBLEM_KEY', 'GB_PROBLEM_NAME', "GB_DEPT_NAME"]]
    for i in range(len(embedding_1)):
        # One sort per HK row (the old code re-sorted on every GB column).
        row_scores = cos_sim[i].tolist()
        for j in _top_k_indices(row_scores, TOP_K):
            realist.append([PROBLEM_KEY_1[i], sentence_1[i], DEPT_NAME_1[i],
                            PROBLEM_KEY_2[j], sentence_2[j], DEPT_NAME_2[j]])

    os.makedirs(OUT_DIR, exist_ok=True)
    df = pd.DataFrame(realist)
    # No `encoding=`: DataFrame.to_excel dropped that argument in pandas 2.0,
    # and .xlsx files store text as Unicode regardless.
    df.to_excel(os.path.join(OUT_DIR, 'HK_GB_result.xlsx'), index=False,
                header=False)
预训练模型的选取依据:
参考依据:测试所用预训练模型的选取依据。