下面基于上篇文章使用到的 Chinese-medical-dialogue-data
中文医疗对话数据作为知识内容进行实验。
本篇实验使用 ES
版本为:7.14.0
二、Chinese-medical-dialogue-data 数据集
GitHub
地址如下:
数据分了 6
个科目类型:
数据格式如下所示:
其中 ask
为病症的问题描述,answer
为病症的回答。
由于数据较多,本次实验仅使用 IM_内科
数据的前 5000
条数据进行测试。
三、Embedding 模型
Embedding
模型使用开源的 chinese-roberta-wwm-ext-large
,该模型输出为 1024
维。
huggingface
地址:
基本使用如下:
from transformers import BertTokenizer, BertModel
import torch
模型下载的地址
model_name = ‘D:\AIGC\model\chinese-roberta-wwm-ext-large’
def embeddings(docs, max_length=300):
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)
对文本进行分词、编码和填充
input_ids = []
attention_masks = []
for doc in docs:
encoded_dict = tokenizer.encode_plus(
doc,
add_special_tokens=True,
max_length=max_length,
padding=‘max_length’,
truncation=True,
return_attention_mask=True,
return_tensors=‘pt’
)
input_ids.append(encoded_dict[‘input_ids’])
attention_masks.append(encoded_dict[‘attention_mask’])
input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
前向传播
with torch.no_grad():
outputs = model(input_ids, attention_mask=attention_masks)
提取最后一层的CLS向量作为文本表示
last_hidden_state = outputs.last_hidden_state
cls_embeddings = last_hidden_state[:, 0, :]
return cls_embeddings</