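This notebook shows how to build a question answering engine from scratch with Milvus and Towhee. Milvus is the most advanced open-source vector database built for AI applications, and it supports nearest-neighbor embedding search across tens of millions of entries. Towhee is a framework that provides ETL for unstructured data using state-of-the-art (SoTA) machine learning models.
We will walk through the question answering workflow and check how well it performs. With Towhee, the core functionality fits in roughly 10 lines of code, so you can start building your own question answering engine.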
Preparation
Install Dependencies
First, install the dependencies, including towhee, towhee.models, and gradio.
! python -m pip install -q towhee towhee.models gradio
Prepare the Data
This demo uses a subset of the InsuranceQA Corpus (1,000 question-answer pairs), which you can download from GitHub:
! curl -L https://github.com/towhee-io/examples/releases/download/data/question_answer.csv -O
question_answer.csv: a file containing the questions and their answers.
Let's take a quick look:
import pandas as pd
df = pd.read_csv('question_answer.csv')
df.head()
index | id | question | answer |
---|---|---|---|
0 | 0 | Is Disability Insurance Required By Law? | Not generally. There are five states that requ... |
1 | 1 | Can Creditors Take Life Insurance After ... | If the person who passed away was the one with... |
2 | 2 | Does Travelers Insurance Have Renters Ins... | One of the insurance carriers I represent is T... |
3 | 3 | Can I Drive A New Car Home Without Ins... | Most auto dealers will not let you drive the c... |
4 | 4 | Is The Cash Surrender Value Of Life Ins... | Cash surrender value comes only with Whole Lif... |
To look up answers from this dataset, we first define a dictionary:
id_answer: a dictionary that maps each id to its answer.
id_answer = df.set_index('id')['answer'].to_dict()
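For readers without pandas, here is a rough standard-library equivalent of the line above. It is a sketch that assumes the CSV header is `id,question,answer`, as the preview shows. The two rows are shortened, hypothetical stand-ins and not the real data. Note the `int(...)` conversion. pandas infers the `id` column as integers, so the keys of `id_answer` are ints. This is why the search pipeline later wraps each returned id in `int(...)`.

```python
import csv
import io

# Hypothetical, shortened sample in the same layout as question_answer.csv
sample_csv = io.StringIO(
    "id,question,answer\n"
    "0,Is Disability Insurance Required By Law?,Not generally.\n"
    "1,Hypothetical question?,Hypothetical answer.\n"
)

def build_id_answer(fileobj):
    """Map each integer id to its answer, like df.set_index('id')['answer'].to_dict()."""
    reader = csv.DictReader(fileobj)
    # csv yields strings, so convert the id to int to match the dtype pandas infers
    return {int(row['id']): row['answer'] for row in reader}

id_answer_demo = build_id_answer(sample_csv)
print(id_answer_demo[0])  # Not generally.
```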
Create a Milvus Collection
Before you begin, make sure the Milvus service is running. This notebook uses Milvus 2.2.10 and pymilvus 2.2.11.
! python -m pip install -q pymilvus==2.2.11
Next, define the function create_milvus_collection. It creates a collection in Milvus that uses the L2 distance metric and an IVF_FLAT index.
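The toy, pure-Python sketch below illustrates what "IVF_FLAT" and `nlist` mean. Each vector is assigned to the bucket ("inverted list") of its nearest centroid, and there are `nlist` buckets. A query then scans only the `nprobe` closest buckets and ranks their members by exact ("flat") L2 distance. This is an illustration of the idea, not Milvus's implementation. The centroids here are hand-picked, while a real IVF index learns them, typically with k-means. `nprobe` is a search-time parameter that this notebook does not set explicitly. The sketch ranks by squared L2 distance, which gives the same order as plain L2 because the square root is monotonic.

```python
def l2_sq(a, b):
    # Squared Euclidean distance between two equal-length vectors
    return sum((x - y) ** 2 for x, y in zip(a, b))

def build_ivf(vectors, centroids):
    """Assign every vector id to the inverted list of its nearest centroid."""
    lists = {c: [] for c in range(len(centroids))}
    for vid, v in vectors.items():
        nearest = min(range(len(centroids)), key=lambda c: l2_sq(v, centroids[c]))
        lists[nearest].append(vid)
    return lists

def ivf_search(query, vectors, centroids, lists, nprobe, limit):
    """Probe the nprobe closest lists, then rank their members by exact L2."""
    probe = sorted(range(len(centroids)), key=lambda c: l2_sq(query, centroids[c]))[:nprobe]
    candidates = [vid for c in probe for vid in lists[c]]
    scored = [(vid, l2_sq(query, vectors[vid])) for vid in candidates]
    return sorted(scored, key=lambda t: t[1])[:limit]

# Hypothetical 2-D data; the real collection stores 768-dimensional vectors
vectors = {'a': (0.0, 0.0), 'b': (0.1, 0.0), 'c': (5.0, 5.0), 'd': (5.1, 4.9)}
centroids = [(0.0, 0.0), (5.0, 5.0)]  # hand-picked here, learned by the index in practice
lists = build_ivf(vectors, centroids)
hits = ivf_search((4.9, 5.0), vectors, centroids, lists, nprobe=1, limit=2)
print([vid for vid, _ in hits])  # ['c', 'd']
```

When `nprobe` equals the number of buckets, the search falls back to brute force. A smaller `nprobe` scans fewer vectors and may miss some true neighbors.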
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
# connect to the local Milvus server
connections.connect(host='127.0.0.1', port='19530')
def create_milvus_collection(collection_name, dim):
    # drop any existing collection with the same name so the notebook can be re-run
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)
    fields = [
        FieldSchema(name='id', dtype=DataType.VARCHAR, description='ids', max_length=500, is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='question answering')
    collection = Collection(name=collection_name, schema=schema)
    # create IVF_FLAT index for collection.
    index_params = {
        'metric_type': 'L2',
        'index_type': 'IVF_FLAT',
        'params': {'nlist': 2048}
    }
    collection.create_index(field_name='embedding', index_params=index_params)
    return collection
# dim=768 matches the embedding size of the DPR model used below
collection = create_milvus_collection('question_answer', 768)
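Question Answering Engine
In this section, we show how to build the question answering engine with Milvus and Towhee. The basic idea is to use Towhee to generate embeddings from the question dataset and store them in Milvus. We then compare the embedding of an input question against the stored embeddings.
Towhee is a machine learning framework for building data processing pipelines. It also provides predefined operators for inserting data into and querying Milvus.
Load Question Embeddings into Milvus
We first generate embeddings from the question text with the dpr operator and then insert them into Milvus. Towhee provides a method-chaining API that lets users assemble a data processing pipeline from operators.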
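The hypothetical pure-Python class below illustrates the method-chaining style. Each `.map(src, dst, fn)` call returns a new pipeline with one extra step. `.output(...)` returns a callable that runs the steps over a dict of named values. `MiniPipe` and `mini_input` are illustrative names only. They mimic the shape of `pipe.input(...).map(...).output(...)` and are not Towhee's implementation, which does considerably more.

```python
class MiniPipe:
    """Hypothetical, simplified stand-in for a Towhee-style pipeline."""

    def __init__(self, input_names, steps=()):
        self.input_names = tuple(input_names)
        self.steps = list(steps)

    def map(self, src, dst, fn):
        # Return a new pipeline with one extra step, so calls can be chained
        return MiniPipe(self.input_names, self.steps + [(src, dst, fn)])

    def output(self, *names):
        def run(*values):
            row = dict(zip(self.input_names, values))
            for src, dst, fn in self.steps:
                # A tuple of source names is unpacked into several arguments
                if isinstance(src, tuple):
                    row[dst] = fn(*(row[s] for s in src))
                else:
                    row[dst] = fn(row[src])
            return tuple(row[n] for n in names)
        return run

def mini_input(*names):
    return MiniPipe(names)

# Toy "embedding": one number per word (its length), for illustration only
demo = (
    mini_input('id', 'question')
    .map('question', 'vec', lambda q: [float(len(w)) for w in q.split()])
    .map(('id', 'vec'), 'status', lambda i, v: f'{i}:{len(v)}')
    .output('status')
)
print(demo('0', 'Is Disability Insurance Required'))  # ('0:4',)
```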
%%time
from towhee import pipe, ops
import numpy as np
from towhee.datacollection import DataCollection
insert_pipe = (
    pipe.input('id', 'question', 'answer')
    .map('question', 'vec', ops.text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base'))
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
    .map(('id', 'vec'), 'insert_status', ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='question_answer'))
    .output()
)
import csv
with open('question_answer.csv', encoding='utf-8') as f:
    reader = csv.reader(f)
    next(reader)  # skip the header row: id, question, answer
    for row in reader:
        insert_pipe(*row)  # each row is passed as (id, question, answer)
CPU times: user 2min 37s, sys: 3min 59s, total: 6min 37s
Wall time: 1min 27s
print('Total number of inserted data is {}.'.format(collection.num_entities))
Total number of inserted data is 1000.
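Pipeline Explanation
Here is what each line of the pipeline does:
insert_pipe = (
    # take three inputs: the question id, the question text, and the answer
    pipe.input('id', 'question', 'answer')
    # generate the question embedding with the dpr operator from Towhee hub, using the facebook/dpr-ctx_encoder-single-nq-base model
    .map('question', 'vec', ops.text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base'))
    # normalize the embedding vector to unit length
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
    # insert the id and the question embedding into Milvus
    .map(('id', 'vec'), 'insert_status', ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='question_answer'))
    # no output columns are needed; the pipeline runs for its side effect (the insert)
    .output()
)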
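The normalization step matters because the collection uses the L2 metric. For unit-length vectors a and b, ||a - b||^2 = 2 - 2 * cos(a, b). Ranking by smallest L2 distance therefore gives the same order as ranking by largest cosine similarity. The short check below verifies this numerically. It uses plain Python in place of NumPy, and the vectors are arbitrary examples.

```python
import math

def normalize(v):
    # Divide by the Euclidean norm, like x / np.linalg.norm(x, axis=0) for a 1-D vector
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def l2_sq(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

a = normalize([3.0, 4.0])
b = normalize([1.0, 2.0])
# For unit vectors the dot product is the cosine similarity
print(math.isclose(l2_sq(a, b), 2 - 2 * dot(a, b)))  # True

# The nearest candidate by L2 is also the most similar by cosine
query = normalize([1.0, 0.5])
cands = [normalize(v) for v in ([1.0, 0.0], [0.0, 1.0], [1.0, 1.0])]
best_l2 = min(range(len(cands)), key=lambda i: l2_sq(query, cands[i]))
best_cos = max(range(len(cands)), key=lambda i: dot(query, cands[i]))
print(best_l2 == best_cos)  # True
```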
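Ask Questions with Milvus and Towhee
Now that the question embeddings are in Milvus, we can ask questions with Milvus and Towhee. As before, we use Towhee to load the input question and compute its embedding, and we use that embedding as the query in Milvus. Milvus returns only IDs and distances, so we use the id_answer dictionary to look up and display the answer for each returned ID.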
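The ID-to-answer step can also be written as a slightly more defensive helper. It is a sketch under an assumption inferred from the `int(i[0])` in the pipeline below: each search hit is a sequence whose first element is the primary key, returned as a string because the `id` field is VARCHAR. The `[id, distance]` shape of the demo hits and the helper name `hits_to_answers` are hypothetical. Unlike the one-line lambda, this helper skips ids that are missing from the dictionary instead of raising a `KeyError`.

```python
def hits_to_answers(hits, id_answer):
    """Turn search hits into answer strings (assumes hit[0] is the VARCHAR primary key)."""
    answers = []
    for hit in hits:
        key = int(hit[0])  # string primary key -> int key used by the pandas-built dict
        if key in id_answer:
            answers.append(id_answer[key])
    return answers

# Hypothetical hits in the assumed [id, distance] shape, as with limit=3
demo_hits = [['1', 0.12], ['0', 0.35], ['999', 0.80]]
demo_answers = {0: 'answer zero', 1: 'answer one'}
print(hits_to_answers(demo_hits, demo_answers))  # ['answer one', 'answer zero']
```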
%%time
collection.load()
ans_pipe = (
    pipe.input('question')
    .map('question', 'vec', ops.text_embedding.dpr(model_name="facebook/dpr-ctx_encoder-single-nq-base"))
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
    .map('vec', 'res', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='question_answer', limit=1))
    .map('res', 'answer', lambda x: [id_answer[int(i[0])] for i in x])
    .output('question', 'answer')
)
ans = ans_pipe('Is Disability Insurance Required By Law?')
ans = DataCollection(ans)
ans.show()
question | answer |
---|---|
Is Disability Insurance Required By Law? | Not generally. There are five states that require most all employers carry short term disability insurance on their employees. T... |
CPU times: user 1.12 s, sys: 375 ms, total: 1.49 s
Wall time: 16.7 s
We can then retrieve the answer to the question 'Is Disability Insurance Required By Law?':
ans[0]['answer']
['Not generally. There are five states that require most all employers carry short term disability insurance on their employees. These states are: California, Hawaii, New Jersey, New York, and Rhode Island. Besides this mandatory short term disability law, there is no other legislative imperative for someone to purchase or be covered by disability insurance.']
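Release a Showcase
The core functionality of the question answering engine is now complete, so it is time to build an interface to show it off. Gradio is a great tool for building demos. With Gradio, we only need to wrap the data processing pipeline in a chat function: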
# build the pipeline once, outside the callback, so it and its operators are not re-created on every message
chat_pipe = (
    pipe.input('question')
    .map('question', 'vec', ops.text_embedding.dpr(model_name="facebook/dpr-ctx_encoder-single-nq-base"))
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
    .map('vec', 'res', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='question_answer', limit=1))
    .map('res', 'answer', lambda x: [id_answer[int(i[0])] for i in x])
    .output('question', 'answer')
)
def chat(message, history):
    history = history or []
    # get() returns the output row [question, answer]; answer is a list, so take its first element
    response = chat_pipe(message).get()[1][0]
    history.append((message, response))
    return history, history
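The sketch below shows what `chat` hands back to Gradio without starting Milvus or loading the model. It replaces the pipeline call with a hypothetical `fake_answer` stub. Each call appends a `(user message, bot reply)` pair to the history and returns the same list twice. These two values correspond to the two output components in the interface below: the `chatbot` display and the `"state"` that is fed back into the next call.

```python
def fake_answer(question):
    # Hypothetical stand-in for the pipeline call inside chat
    return f'(answer to: {question})'

def chat_demo(message, history):
    history = history or []
    history.append((message, fake_answer(message)))
    # Returned twice: once for the Chatbot display, once for the "state" component
    return history, history

shown, state = chat_demo('Hi', None)
shown, state = chat_demo('Is Disability Insurance Required By Law?', state)
print(len(state))  # 2
```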
import gradio
collection.load()
# note: color_map and allow_screenshot are arguments from older Gradio 3.x releases; newer Gradio versions may reject them
chatbot = gradio.Chatbot(color_map=("green", "gray"))
interface = gradio.Interface(
    chat,
    ["text", "state"],
    [chatbot, "state"],
    allow_screenshot=False,
    allow_flagging="never",
)
interface.launch(inline=True, share=True)
Running on local URL: http://127.0.0.1:7860
Original notebook
nlp/question_answering/1_build_question_answering_engine.ipynb