BERT结合Faiss实现问答检索(少于50行代码)

本文主要介绍一个框架nlp-basictasks
nlp-basictasks是利用PyTorch深度学习框架所构建一个简单的库,旨在快速搭建模型完成一些基础的NLP任务,如分类、匹配、序列标注、语义相似度计算等。

下面利用该框架实现问答检索

所谓检索,就是将本地所有待检索的每一个句子编码成一个向量,将这些向量存储起来,称之为索引index。然后针对用户的问题query,将其编码为一个对应的向量query_vector,计算query_vector和索引中每一个vector的相似度,得到相似度最高的几个vector,最后返回这些vector对应的问题。
此外还需要一个变量query2id,
query2id指的是每一个问题和对应的id之间的映射,这个id就是索引中这个问题对应的向量的id,因为当我们计算完相似度最高的几个向量后,还要根据向量在索引中的id从query2id中找出对应的问题

数据集介绍

实验所用数据集是常用的中文自然语言推理数据集lcqmc,来源http://icrc.hitsz.edu.cn/Article/show/171.html

all_sentences=[]
data_folder='lcqmc/lcqmc_test.tsv'#我们只取少量的test.tsv实验
with open(data_folder) as f:
    lines=f.readlines()
    for line in lines[1:]:
        line_split=line.strip().split('\t')
        all_sentences.append(line_split[0])
        all_sentences.append(line_split[1])

导入包

import sys,os
from nlp_basictasks.webservices.sts_retrieve import RetrieveModel
import nlp_basictasks

定义路径加载模型

model_name_or_path='' #model_name_or_path指的就是你下载的BERT模型存放的路径 如:'chinese-roberta-wwm/'
save_index_path='' #这个变量代表你所要存储的索引的路径,如:'faiss_index.index'
save_query2id_path='' #这个变量代表你要存储的query2id的路径,如:'query2id.json'。
IR_model=RetrieveModel(save_index_path=save_index_path,
                       save_query2id_path=save_query2id_path,
                       encode_dim=768,
                      model_name_or_path=model_name_or_path)

建立索引

IR_model.createIndex(all_sentences)

问答检索

import time
start_time=time.time()
result=IR_model.retrieval("那个人正在玩电子游戏",topk=10)#检索回10个最相似的问题
end_time=time.time()
print("从%d个问题中检索一个问题需要%f ms"%(IR_model.index.ntotal,(end_time-start_time)*1000))

在这里插入图片描述
此外还支持动态的向索引中添加句子和删除索引中的句子,相关细节见nlp-basictasks框架做问答检索

不到50行代码即可实现问答检索,觉得好用的话还请点个star,谢谢。

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
以下是使用BERT-ETM模型进问答代码示例: 1. 导入所需的库和模型 ```python import torch from transformers import BertTokenizer, BertForQuestionAnswering from etm.etm import ETM from etm.utils import get_device device = get_device() tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") bert_model = BertForQuestionAnswering.from_pretrained("bert-base-uncased").to(device) etm_model = ETM(num_topics=50, num_embeddings=10000, hidden_size=512, num_layers=2).to(device) etm_model.load_state_dict(torch.load("path/to/etm/model.pth")) ``` 2. 定义问答函数 ```python def answer_question(question, context): # 对上下文和问题进编码 encoded_dict = tokenizer.encode_plus(question, context, add_special_tokens=True, max_length=256, return_tensors='pt') input_ids = encoded_dict['input_ids'].to(device) attention_mask = encoded_dict['attention_mask'].to(device) # 使用BERT模型预测答案的起始和结束位置 start_scores, end_scores = bert_model(input_ids, attention_mask=attention_mask) start_index = torch.argmax(start_scores) end_index = torch.argmax(end_scores) # 根据预测的起始和结束位置提取答案 tokens = tokenizer.convert_ids_to_tokens(input_ids[0]) answer_tokens = tokens[start_index:end_index+1] answer = tokenizer.convert_tokens_to_string(answer_tokens) # 使用ETM模型对答案进主题建模 with torch.no_grad(): embedding = etm_model.get_embedding_for_words([answer]).to(device) topic_weights = etm_model.get_topic_weights(embedding) topic_index = torch.argmax(topic_weights) # 返回答案和主题 return answer, topic_index ``` 3. 使用问答函数 ```python context = "The PyTorch library is used for building deep neural networks. It is one of the most popular open-source libraries for deep learning. PyTorch was developed by Facebook and is written in Python. It has a dynamic computational graph, which makes it easier to debug and optimize deep learning models." question = "Who developed PyTorch?" answer, topic = answer_question(question, context) print(f"Answer: {answer}") print(f"Topic index: {topic}") ``` 输出结果: ``` Answer: Facebook Topic index: 23 ``` 其中,主题索引23表示答案与主题模型中的第23个主题最相关。可以根据需要进进一步的主题分析和处理。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值