手把手医学知识图谱搭建案例
注:大家觉得博客好的话,别忘了点赞收藏呀,本人每周都会更新关于人工智能和大数据相关的内容,内容多为原创,Python Java Scala SQL 代码,CV NLP 推荐系统等,Spark Flink Kafka Hbase Hive Flume等等~写的都是纯干货,各种顶会的论文解读,一起进步。
今天和大家分享一下医学知识图谱中三元组搭建的案例
github: https://github.com/king-yyf/CMeKG_tools
#博学谷IT学习技术支持#
文章目录
前言
知识图谱(Knowledge Graph)被运用在很多科研领域,其重要程度不言而喻,很多大厂都在致力于搭建属于自己的知识图谱。其中知识抽取是搭建知识图谱的核心环节,今天和大家分享一下医学知识图谱中三元组搭建的案例。
一、先来看最终效果展示
搭建出来的三元组准确性还是相当不错的
再来看看最终的知识图谱效果图,通过模型有了三元组,利用Neo4j搭建知识图谱so easy~
二、核心代码
1.引入库
这里最重要的是导入transformers包,因为整个预训练模型是通过huggingface去做的,用的还是bert。
代码如下(示例):
import json
import numpy as np
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel, AdamW
from itertools import cycle
import gc
import random
import time
import re
# Central place for hyper-parameters, file paths and shared lookup tables.
class config:
    """Static configuration shared by the dataset, the models and the training loop."""

    # --- hyper-parameters ---
    batch_size = 1
    max_seq_len = 256  # encoded texts are padded to this many token ids
    num_p = 23  # size of the predicate axis in the object-label tensors
    learning_rate = 1e-5
    EPOCH = 2

    # --- file locations ---
    PATH_SCHEMA = "/CMeKG/predicate.json"
    PATH_TRAIN = '/CMeKG/train_example.json'
    PATH_BERT = "/CMeKG/model/medical_re"
    PATH_MODEL = "/CMeKG/model/medical_re/model_re.pkl"
    PATH_SAVE = '/CMeKG/model/save'

    # The vocabulary lives next to the BERT weights, so derive its path from
    # PATH_BERT instead of repeating the directory literal.
    tokenizer = BertTokenizer.from_pretrained(PATH_BERT + "/vocab.txt")

    # Predicate <-> integer id lookup tables; empty until a schema is loaded.
    id2predicate = {}
    predicate2id = {}
    # The dataset code looks the table up as ``config.prediction2id``. Expose
    # that spelling as an alias of the same dict so both names resolve.
    # NOTE(review): a loader that rebinds one of the two names (instead of
    # filling the dict in place) must rebind the other as well.
    prediction2id = predicate2id
2.训练数据
其中PATH_TRAIN = '/CMeKG/train_example.json'是训练数据
{
  "text": "12小时尿沉渣计数的相关疾病:单纯型尿路感染,妊娠合并急性膀胱炎,慢性肾炎,狼疮性肾炎,急性膀胱炎12小时尿沉渣计数的相关症状是高血压,男子性功能障碍,蛋白尿,血尿,水肿,排尿困难及尿潴留,尿频伴尿急和尿痛",
  "spo_list": [
    [
      "12小时尿沉渣计数",
      "相关疾病",
      "单纯型尿路感染"
    ]
  ]
}
其中text是原文,spo_list是三元组列表,每个三元组中s代表主体,p代表关系,o代表客体。
3.数据预处理
这里我都已经写好注释了,这是一个数据预处理逻辑
代码如下(示例):
class IterableDataset(torch.utils.data.IterableDataset):
    """Endless stream of fixed-size NumPy batches built from (text, spo_list) samples.

    Each sample is a dict with a ``'text'`` string and a ``'spo_list'`` of
    ``[subject, predicate, object]`` triples. Iterating the dataset yields
    6-tuples of arrays (see :meth:`process_data`).
    """

    def __init__(self, data, random):
        """Store the samples and whether to shuffle them on every pass.

        Args:
            data: sequence of sample dicts.
            random: truthy to shuffle the sample order in ``process_data``.
        """
        # The original ``super(IterableDataset).__init__()`` built an unbound
        # super object and never ran the base-class initialiser.
        super().__init__()
        self.data = data
        self.random = random
        self.tokenizer = config.tokenizer

    def __len__(self):
        """Return the number of samples (not the number of batches)."""
        return len(self.data)

    def search(self, sequence, pattern):
        """Return the first index at which ``pattern`` occurs in ``sequence``, or -1.

        Both arguments are lists; slices of ``sequence`` are compared with
        ``pattern`` using list equality.
        """
        n = len(pattern)
        for i in range(len(sequence)):
            if sequence[i:i + n] == pattern:
                return i
        return -1

    def process_data(self):
        """Generate one pass of batches over ``self.data``.

        Yields:
            ``(token_ids, mask_ids, segment_ids, subject_labels, subject_ids,
            object_labels)`` with shapes ``(B, L)``, ``(B, L)``, ``(B, L)``,
            ``(B, L, 2)``, ``(B, 2)`` and ``(B, L, num_p, 2)`` where
            ``B = config.batch_size`` and ``L = config.max_seq_len``.
            The last batch of a pass may be only partially filled; its unused
            rows are all zeros. Every yielded array is an independent copy.
        """
        idxs = list(range(len(self.data)))
        if self.random:
            np.random.shuffle(idxs)
        batch_size = config.batch_size
        max_seq_len = config.max_seq_len
        num_p = config.num_p

        # ``np.int`` was only an alias of the builtin ``int`` and was removed
        # in NumPy 1.24, so use ``int`` directly.
        batch_token_ids = np.zeros((batch_size, max_seq_len), dtype=int)
        batch_mask_ids = np.zeros((batch_size, max_seq_len), dtype=int)
        # Segment ids are never written: every token stays in segment 0.
        batch_segment_ids = np.zeros((batch_size, max_seq_len), dtype=int)
        batch_subject_ids = np.zeros((batch_size, 2), dtype=int)
        batch_subject_labels = np.zeros((batch_size, max_seq_len, 2), dtype=int)
        batch_object_labels = np.zeros((batch_size, max_seq_len, num_p, 2), dtype=int)

        batch_i = 0
        last_position = len(idxs) - 1
        for position, sample_idx in enumerate(idxs):
            text = self.data[sample_idx]['text']
            batch_token_ids[batch_i, :] = self.tokenizer.encode(text, max_length=max_seq_len, pad_to_max_length=True,
                                                                add_special_tokens=True)
            # Mark real tokens with 1; padded positions stay 0.
            # NOTE(review): the length comes from the character count (+2 for
            # the two special tokens), which assumes one token per character
            # - confirm for texts containing digits or Latin words.
            batch_mask_ids[batch_i, :len(text) + 2] = 1
            token_list = list(batch_token_ids[batch_i, :])

            spo_list = self.data[sample_idx]['spo_list']
            # Pick one triple at random; its subject is the conditioning
            # subject for the object labels of this sample.
            idx = np.random.randint(0, len(spo_list), size=1)[0]
            s_rand = self.tokenizer.encode(spo_list[idx][0])[1:-1]  # drop the two special tokens
            s_rand_idx = self.search(token_list, s_rand)
            # NOTE(review): if the subject is not found (s_rand_idx == -1) the
            # stored span is built from -1 - confirm this is acceptable.
            batch_subject_ids[batch_i, :] = [s_rand_idx, s_rand_idx + len(s_rand) - 1]

            # Iterate the triples directly. The original reused ``i`` here,
            # which clobbered the outer sample index and broke the
            # "last sample" test further down.
            for spo in spo_list:
                s = self.tokenizer.encode(spo[0])[1:-1]
                p = config.prediction2id[spo[1]]
                o = self.tokenizer.encode(spo[2])[1:-1]
                s_idx = self.search(token_list, s)
                o_idx = self.search(token_list, o)
                if s_idx != -1 and o_idx != -1:
                    # Every subject in the text is a positive for the
                    # per-token start / end subject tagger.
                    batch_subject_labels[batch_i, s_idx, 0] = 1
                    batch_subject_labels[batch_i, s_idx + len(s) - 1, 1] = 1
                    if s_idx == s_rand_idx:
                        # Object span boundaries, stored under predicate p,
                        # only for the sampled subject.
                        batch_object_labels[batch_i, o_idx, p, 0] = 1
                        batch_object_labels[batch_i, o_idx + len(o) - 1, p, 1] = 1

            batch_i += 1
            if batch_i == batch_size or position == last_position:
                # Yield copies: the buffers are zeroed and reused right after,
                # and consumers that keep batches (``itertools.cycle`` does)
                # would otherwise see blanked data.
                yield (batch_token_ids.copy(), batch_mask_ids.copy(), batch_segment_ids.copy(),
                       batch_subject_labels.copy(), batch_subject_ids.copy(), batch_object_labels.copy())
                batch_token_ids[:, :] = 0
                batch_mask_ids[:, :] = 0
                batch_subject_ids[:, :] = 0
                batch_subject_labels[:, :, :] = 0
                batch_object_labels[:, :, :, :] = 0
                batch_i = 0

    def get_stream(self):
        """Return an iterator that repeats the batches of one pass forever."""
        return cycle(self.process_data())

    def __iter__(self):
        """Start a fresh endless batch stream."""
        return self.get_stream()
4.平平无奇的ner模型Model4s
这里和普通的ner任务完全一样,就是调用huggingface的bert预训练模型接口
代码如下(示例):
class Model4s(nn.Module):
    """Subject tagger: a BERT encoder followed by a per-token 2-way scoring head."""

    def __init__(self, hidden_size=768):
        """Load the encoder from ``config.PATH_BERT`` and build the scoring head.

        Args:
            hidden_size: width of the encoder output fed to the linear layer.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(config.PATH_BERT)
        # Keep every encoder weight trainable (fine-tuning, not feature extraction).
        for weight in self.bert.parameters():
            weight.requires_grad = True
        self.dropout = nn.Dropout(p=0.2)
        self.linear = nn.Linear(hidden_size, 2, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, input_mask, segment_ids, hidden_size=768):
        """Score every token and also return the encoder output.

        Args:
            input_ids: token ids passed to the encoder.
            input_mask: forwarded to the encoder as ``attention_mask``.
            segment_ids: forwarded to the encoder as ``token_type_ids``.
            hidden_size: unused; kept so existing call sites keep working.

        Returns:
            ``(scores, hidden_states)`` where ``scores`` is the squared
            sigmoid of the linear head applied to ``hidden_states`` (two
            values per token) and ``hidden_states`` is element 0 of the
            encoder output.
        """
        encoder_outputs = self.bert(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=segment_ids,
        )
        hidden_states = encoder_outputs[0]  # (batch_size, sequence_length, hidden_size)
        logits = self.linear(self.dropout(hidden_states))
        scores = self.sigmoid(logits).pow(2)
        return scores, hidden_states
5.很有创意的Model4po
这里通过上面的Model4s找到text中的主体后,固定主体,然后找到相应的客体和关系,相当精彩的想法。也是整个项目精彩之处。
代码如下(示例):
class Model4po(nn.Module):
    """Predicate/object tagger conditioned on one subject span per sample."""

    def __init__(self, num_p=config.num_p, hidden_size=768):
        """Build the scoring head.

        Args:
            num_p: number of predicates; the head emits ``num_p * 2`` values
                per token.
            hidden_size: width of the incoming hidden states.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=0.4)
        self.linear = nn.Linear(in_features=hidden_size, out_features=num_p * 2, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, hidden_states, batch_subject_ids, input_mask):
        """Score every token for every predicate, given the chosen subject.

        Args:
            hidden_states: ``(batch, seq_len, hidden)`` encoder output.
            batch_subject_ids: per-sample ``[start, end]`` token indices of
                the subject span.
            input_mask: ``(batch, seq_len)`` mask whose row sum is used as
                the number of real tokens in that sample.

        Returns:
            Tensor of shape ``(batch, seq_len, num_p * 2)``.

        The input ``hidden_states`` is left untouched. The original added the
        subject feature in place, which altered the caller's tensor, so
        calling the model twice on the same encoder output accumulated the
        subject feature.
        """
        # zeros_like keeps device and dtype aligned with the encoder output;
        # a plain torch.zeros(...) lands on the CPU and breaks GPU runs.
        all_s = torch.zeros_like(hidden_states)
        for b in range(hidden_states.shape[0]):
            s_start = batch_subject_ids[b][0]
            s_end = batch_subject_ids[b][1]
            # Subject representation = start-token feature + end-token feature.
            s = hidden_states[b][s_start] + hidden_states[b][s_end]
            cue_len = torch.sum(input_mask[b])  # number of non-padded tokens
            # Broadcast the subject feature over every real token position.
            all_s[b, :cue_len, :] = s
        # Each position now carries its own feature plus the subject feature.
        fused = hidden_states + all_s
        # NOTE(review): the original author guesses that pow(4) is there to
        # damp over-confident scores - not verified.
        output = self.sigmoid(self.linear(self.dropout(fused))).pow(4)
        return output  # (batch_size, max_seq_len, num_p*2)
6.正常训练模型
写的也比较普通,把2个模型的损失加起来就行。
def train(train_data_loader, model4s, model4po, optimizer):
    """Jointly train the subject model and the predicate/object model.

    Runs ``config.EPOCH`` epochs. In each epoch it consumes
    ``len(train_data_loader) // config.batch_size`` batches, sums the two
    masked losses, and takes one optimizer step per batch. Progress is
    printed per batch and per epoch.

    Args:
        train_data_loader: iterable with ``len()`` that yields 6-tuples of
            arrays (token ids, mask ids, segment ids, subject labels,
            subject ids, object labels).
        model4s: model returning ``(subject_scores, hidden_states)``.
        model4po: model returning object scores from
            ``(hidden_states, subject_ids, mask)``.
        optimizer: optimizer over the parameters to update.

    Returns:
        dict with the state dicts of both models and of the optimizer.

    ``loss_fn`` is defined outside this block; its result is averaged over
    the last axis here, so it presumably returns an unreduced element-wise
    loss - confirm.
    """

    def _masked_token_mean(elementwise_loss, mask):
        # Average over the label axis, zero out padded tokens, then
        # normalise by the number of real tokens.
        per_token = torch.mean(elementwise_loss, dim=2, keepdim=False) * mask
        total = torch.sum(per_token)
        return total / torch.sum(mask)

    for epoch in range(config.EPOCH):
        begin_time = time.time()
        model4s.train()
        model4po.train()
        train_loss = 0.

        for step, batch in enumerate(train_data_loader):
            # The loader may be endless, so stop after one epoch's worth of batches.
            if step >= len(train_data_loader) // config.batch_size:
                break

            (token_ids_np, mask_ids_np, segment_ids_np,
             subject_labels_np, subject_ids_np, object_labels_np) = batch

            token_ids = torch.tensor(token_ids_np, dtype=torch.long)
            mask_ids = torch.tensor(mask_ids_np, dtype=torch.long)
            segment_ids = torch.tensor(segment_ids_np, dtype=torch.long)
            subject_labels = torch.tensor(subject_labels_np, dtype=torch.float)
            # Flatten the (num_p, 2) label axes to match the model output.
            object_labels = torch.tensor(object_labels_np, dtype=torch.float).view(
                config.batch_size, config.max_seq_len, config.num_p * 2)
            subject_ids = torch.tensor(subject_ids_np, dtype=torch.int)

            subject_pred, hidden_states = model4s(token_ids, mask_ids, segment_ids)
            loss4s = _masked_token_mean(
                loss_fn(subject_pred, subject_labels.to(torch.float32)), mask_ids)

            object_pred = model4po(hidden_states, subject_ids, mask_ids)
            loss4po = _masked_token_mean(
                loss_fn(object_pred, object_labels.to(torch.float32)), mask_ids)

            loss = loss4s + loss4po

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = float(loss.item())
            train_loss += batch_loss
            print('batch:', step, 'loss:', batch_loss)

        print('final train_loss:', train_loss / len(train_data_loader) * config.batch_size, 'cost time:',
              time.time() - begin_time)

    del train_data_loader
    gc.collect()

    checkpoint = {
        "model4s_state_dict": model4s.state_dict(),
        "model4po_state_dict": model4po.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    return checkpoint
7.demo案例
输出的结果就是文章开头展示的三元组效果
import medical_re
import json
# Load the two trained models (subject tagger and predicate/object tagger).
model4s, model4po = medical_re.load_model()

# The passage to extract triples from.
text = '据报道称,新冠肺炎患者经常会发热、咳嗽,少部分患者会胸闷、乏力,其病因包括: 1.自身免疫系统缺陷\n2.人传人。'

# Extract the triples and pretty-print them without escaping non-ASCII text.
res = medical_re.get_triples(text, model4s, model4po)
formatted = json.dumps(res, ensure_ascii=False, indent=True)
print(formatted)
总结
通过huggingface中bert预训练模型实现的一个非常不错的知识图谱搭建过程。Neo4j怎么搭建,以后有时间继续更新。