Implementing Semantic Vector Search with Sentence-BERT

Table of Contents

Chinese language model pretraining based on PyTorch: https://github.com/zhusleep/pytorch_chinese_lm_pretrain/tree/master

sentence_emb.py

search_faiss_robert768.py

faiss_index.py

gen_vec_save2_faiss.py


Chinese language model pretraining based on PyTorch: https://github.com/zhusleep/pytorch_chinese_lm_pretrain/tree/master
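
The embeddings in the scripts below come from a Chinese BERT/RoBERTa checkpoint pretrained with the repository above and exported to a local ./result directory, which is the path the scripts load from. As a minimal sanity check, assuming the pretraining run saved the model in Hugging Face format, you can load it and confirm the hidden size matches the 768-dimensional vectors used throughout:

from transformers import BertTokenizer, BertModel

# Assumption: the pretraining output was exported to ./result in Hugging Face format.
tokenizer = BertTokenizer.from_pretrained('./result')
model = BertModel.from_pretrained('./result')
print(model.config.hidden_size)  # expected: 768, matching the faiss index dimension used below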

sentence_emb.py

#from transformers import BertTokenizer, BertModel
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#
## First we initialize our model and tokenizer:
#tokenizer = BertTokenizer.from_pretrained('./result')
#model = BertModel.from_pretrained('./result')


def split_batch(init_list, batch_size):
    # Split a list into batches of size batch_size; the last (shorter) batch is kept.
    groups = zip(*(iter(init_list),) * batch_size)
    end_list = [list(i) for i in groups]
    count = len(init_list) % batch_size
    if count != 0:
        end_list.append(init_list[-count:])
    return end_list
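
# Example: split_batch([1, 2, 3, 4, 5], 2) -> [[1, 2], [3, 4], [5]]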



"""
param: sentence list
return: embeddings
"""
def encode(sentences, tokenizer, model):
    tokens = {'input_ids': [], 'attention_mask': []}
    data_num = len(sentences)

    for sentence in sentences:
        # Encode each sentence and append it to the dict
        new_tokens = tokenizer.encode_plus(str(sentence), max_length=128,
                                           truncation=True, padding='max_length',
                                           return_tensors='pt')
        tokens['input_ids'].append(new_tokens['input_ids'][0])
        tokens['attention_mask'].append(new_tokens['attention_mask'][0])

    # Stack the list of tensors into a single batched tensor
    tokens['input_ids'] = torch.stack(tokens['input_ids']).to(device)
    tokens['attention_mask'] = torch.stack(tokens['attention_mask']).to(device)
    model.eval()

    # We process these tokens through our model:
    with torch.no_grad():  # run inference without tracking gradients
        outputs = model(**tokens)

    # odict_keys(['last_hidden_state', 'pooler_output'])

    # The dense vector representations of our text are contained within the outputs 'last_hidden_state' tensor, which we access like so:

    embeddings = outputs[0]

    # To perform this operation, we first resize our attention_mask tensor:

    attention_mask = tokens['attention_mask']
    # attention_mask.shape

    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    # mask.shape

    # Each token now has a 768-d mask vector mirroring its attention_mask state; multiply it with the embeddings to zero out padding tokens:

    masked_embeddings = embeddings * mask
    # masked_embeddings.shape == torch.Size([data_num, 128, 768])

    summed = torch.sum(masked_embeddings, 1)
    summed_mask = torch.clamp(mask.sum(1), min=1e-9)
    mean_pooled = summed / summed_mask
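    # mean_pooled has shape [data_num, 768]: one mean-pooled sentence embedding per input sentence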

    # print(mean_pooled)

    # print(type(mean_pooled))
    
    return mean_pooled


#sentences = [
#    "你叫什么名字?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#]
#sb = split_batch(sentences, 2)
#embs = []
#for batch in sb:
#	emb = encode(batch, tokenizer, model)
#	embs += emb
#
#print(embs)
#print(len(embs))




search_faiss_robert768.py

import pickle
from faiss_index import faissIndex
import pandas as pd
import numpy as np
# from sentence_transformers import SentenceTransformer
# Download model
# model = SentenceTransformer('paraphrase-MiniLM-L6-v2/')
from sentence_emb import encode


from transformers import BertTokenizer, BertModel
import torch
# First we initialize our model and tokenizer:
tokenizer = BertTokenizer.from_pretrained('./result')
model = BertModel.from_pretrained('./result').cuda()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



# faiss_index_path = "faiss_index384.pkl"
faiss_index_path = "faiss_index_robert.pkl"

symptom_name_df = pd.read_csv("col2.csv")

# Load the faiss index from the local pickle file
def load_faiss_index(var_faiss_model_path):
    with open(var_faiss_model_path, mode='rb') as fr:
        index = pickle.load(fr)
        return index


def symptom_name_recall(symptom_name):
    # Encode the input text into a vector
    sentence = []
    sentence.append(symptom_name)
    # query_emb = model.encode(sentence)
    query_emb = encode(sentence, tokenizer, model)
    # Load the faiss index and retrieve the nearest neighbours
    loaded_faiss_index = load_faiss_index(faiss_index_path)
    # Search for the k nearest items
    # R, D, I = loaded_faiss_index.search_items(query_emb.reshape([-1, 384]), k=10, n_probe=5)
    R, D, I = loaded_faiss_index.search_items(np.array(query_emb.reshape([-1, 768]).cpu()), k=10, n_probe=5)
    # Map the item IDs returned by faiss back to symptom names
    result = []
    for id_list in R:
        for item in id_list:
            result.append(item)
    symptom_name_list = symptom_name_df[symptom_name_df['index'].isin(result)]['symptom_name'].to_list()

    # Remove the query itself from the retrieved results
    if symptom_name in symptom_name_list:
        symptom_name_list.remove(symptom_name)

    print(symptom_name + ' similar terms: ' + str(symptom_name_list))

word_list = ['头痛','恶心吧吐','期饮酒','出血','失眠']
for word in word_list:
    symptom_name_recall(word)

faiss_index.py

import faiss
import numpy as np


class faissIndex:
    def __init__(self, dim, n_centroids, metric):
        self.dim = dim
        self.n_centroids = n_centroids
        assert metric in ('INNER_PRODUCT', 'L2'), "Input metric not in 'INNER_PRODUCT' or 'L2'"
        self.metric = faiss.METRIC_INNER_PRODUCT if metric == 'INNER_PRODUCT' else faiss.METRIC_L2
        self._build_index()
        return

    def _build_index(self):
        self._quantizer = faiss.IndexFlatL2(self.dim)
        self.index = faiss.IndexIVFFlat(self._quantizer, self.dim, self.n_centroids, self.metric)
        self.is_trained = self.index.is_trained
        self.n_samples = 0  # number of vectors currently in the index
        self.items = np.array([])  # items corresponding to the indexed vectors; must stay aligned one-to-one with n_samples
        return True

    def reset_index(self, dim, n_centroids, metric):
        self.dim = dim
        self.n_centroids = n_centroids
        assert metric in ('INNER_PRODUCT', 'L2'), "Input metric not in 'INNER_PRODUCT' or 'L2'"
        self.metric = faiss.METRIC_INNER_PRODUCT if metric == 'INNER_PRODUCT' else faiss.METRIC_L2
        self._build_index()
        return

    def train(self, vectors_train):
        self.index.train(vectors_train)
        self.is_trained = self.index.is_trained
        return

    def add(self, vectors, items=None):
        if items is not None and not items.empty:  # if items are provided, check that vector and item counts match
            assert len(vectors) == len(
                items), "Length of vectors ({n_vectors}) and items ({n_items}) don't match, please check your input.".format(
                n_vectors=len(vectors), n_items=len(items))
            assert self.n_samples == len(
                self.items), "Amounts of added vectors and items don't match, cannot add more items."
            self.items = np.append(self.items, items.to_numpy())
        else:
            assert len(
                self.items) == 0, "There were items added previously, please add corresponding items in this batch."
        self.index.add(vectors)
        self.n_samples += len(vectors)
        return

    def search(self, query_vector, k, n_probe=1):
        assert query_vector.shape[
                   1] == self.dim, "The dimension of query vector ({dim_vector}) doesn't match the training vector set ({dim_index})!".format(
            dim_vector=query_vector.shape[1], dim_index=self.dim)
        assert self.is_trained, "Faiss index is not trained, please train index first!"
        assert self.n_samples > 0, "Faiss index doesn't have any vector for query, please add vectors into index first!"
        self.index.nprobe = n_probe
        D, I = self.index.search(query_vector, k)
        return D, I

    # k = 30  # number of nearest items to return for each query vector (row)
    # n_probe = 5  # number of nearest clusters to probe per query
    def search_items(self, query_vector, k, n_probe=1):
        D, I = self.search(query_vector, k, n_probe)
        R = [self.items[i] for i in I]
        return R, D, I
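
A minimal sketch of exercising this wrapper on its own, with hypothetical toy data: random float32 vectors and a pandas Series of item IDs (the class expects a Series, hence the .empty and .to_numpy() calls in add):

import numpy as np
import pandas as pd
from faiss_index import faissIndex

# Hypothetical toy data: 100 random 768-d vectors with integer item IDs.
vectors = np.random.rand(100, 768).astype('float32')  # faiss works with float32
items = pd.Series(range(100))

index = faissIndex(dim=768, n_centroids=5, metric='L2')
index.train(vectors)       # train the IVF quantizer on the vector pool
index.add(vectors, items)  # add vectors together with their item IDs

# Query with the first two vectors; R maps faiss row indices back to item IDs.
R, D, I = index.search_items(vectors[:2], k=3, n_probe=2)
print(R)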

gen_vec_save2_faiss.py

"""
# 训练语义向量并保存在faiss中
step1: 将句子生成向量
step2: 将向量保存在faiss中
"""
import pandas as pd
import numpy as np
# from sentence_transformers import SentenceTransformer
# Download model
# model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
from sentence_emb import encode
import pickle
from faiss_index import faissIndex
from tqdm import tqdm

faiss_index_path = "faiss_index_robert.pkl"

from transformers import BertTokenizer, BertModel
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# First we initialize our model and tokenizer:
tokenizer = BertTokenizer.from_pretrained('./result')
model = BertModel.from_pretrained('./result').cuda()


# ====================== Build and train the faiss index ======================
def build_faiss_index(df_resources, semantic_vector, n_centroids=5, metric='L2'):
    print("Start training the faiss index")
    # Build the faiss index model
    dim = semantic_vector.shape[1]
    print("Vector dimension:", dim)
    print("Number of cluster centroids:", n_centroids)
    print("Distance metric:", metric)

    # Train the faiss index
    index = faissIndex(dim, n_centroids, metric)
    # vectors = np.stack(df_resources['index'].values).astype('float32')  # faiss only supports float32 queries
    vectors = semantic_vector
    items = df_resources['index']
    index.train(vectors)
    index.add(vectors, items)
    print("faiss index training finished")
    return index


# ====================== Save the faiss index ======================
# Pickle the trained index and save it to the given path
def save_index(index, path):
    print("Saving the faiss index to a local file")
    with open(path, mode='wb') as fw:
        pickle.dump(index, fw)
    print("The faiss index has been saved locally")


def split_batch(init_list, batch_size):
    # Split a list into batches of size batch_size; the last (shorter) batch is kept.
    groups = zip(*(iter(init_list),) * batch_size)
    end_list = [list(i) for i in groups]
    count = len(init_list) % batch_size
    if count != 0:
        end_list.append(init_list[-count:])
    return end_list


"""
# 利用sentence transfermer 生成文本向量
# 训练faiss
# 保存faiss
param: 
"""


def sentence2faiss_transformer():
    df = pd.read_csv('col2.csv')
    train_json = df.to_dict('records')
    # Extract the texts and convert them to vectors
    title_list = [item['symptom_name'] for item in train_json]
    print(len(title_list))
    print("Encoding in progress.......")
    # title_list = title_list[:500]
    sb = split_batch(title_list, 8)

    embeddings = []
    for batch in tqdm(sb):
        try:
            emb = encode(batch, tokenizer, model)
            emb = np.array(emb.to("cpu"))
            for item in emb:
                embeddings.append(item)
        except Exception as e:
            print(e)

    # Build and train the faiss index
    df_resources = pd.DataFrame(train_json)
    print("==================================================")

    print("Start build_faiss_index")
    trained_index = build_faiss_index(df_resources, np.array(embeddings), n_centroids=5, metric='L2')

    print("Start save_index")
    # Save the faiss index to disk
    save_index(trained_index, faiss_index_path)


sentence2faiss_transformer()
