手把手医学知识图谱搭建案例

手把手医学知识图谱搭建案例

注:大家觉得博客好的话,别忘了点赞收藏呀,本人每周都会更新关于人工智能和大数据相关的内容,内容多为原创,Python Java Scala SQL 代码,CV NLP 推荐系统等,Spark Flink Kafka Hbase Hive Flume等等~写的都是纯干货,各种顶会的论文解读,一起进步。
今天和大家分享一下医学知识图谱中三元组搭建的案例
github: https://github.com/king-yyf/CMeKG_tools
#博学谷IT学习技术支持#



前言

知识图谱(Knowledge Graph)被运用在很多科研领域,其重要程度不言而喻,很多大厂都在致力于搭建属于自己的知识图谱,尤其是知识抽取是知识图谱的核心,今天和大家分享一下医学知识图谱中三元组搭建的案例。


一、先来看最终效果展示

搭建出来的三元组准确性还是相当不错的
在这里插入图片描述
再来看看最终的知识图谱效果图,通过模型有了三元组,利用Neo4j搭建知识图谱so easy~
在这里插入图片描述

二、核心代码

1.引入库

这里最重要是导入transformers包,因为整个预训练模型是通过huggingface去做的。还是是用bert。

代码如下(示例):

import json
import numpy as np
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel, AdamW
from itertools import cycle
import gc
import random
import time
import re

#先指定好所有的路径
class config:
    batch_size = 1
    max_seq_len = 256
    num_p = 23
    learning_rate = 1e-5
    EPOCH = 2

    PATH_SCHEMA = "/CMeKG/predicate.json"
    PATH_TRAIN = '/CMeKG/train_example.json'
    PATH_BERT = "/CMeKG/model/medical_re"
    PATH_MODEL = "/CMeKG/model/medical_re/model_re.pkl"
    PATH_SAVE = '/CMeKG/model/save'
    tokenizer = BertTokenizer.from_pretrained("/CMeKG/model/medical_re/vocab.txt")

    id2predicate = {}
    predicate2id = {}

2.训练数据

其中PATH_TRAIN = '/CMeKG/train_example.json’是训练数据
{
“text”: “12小时尿沉渣计数的相关疾病:单纯型尿路感染,妊娠合并急性膀胱炎,慢性肾炎,狼疮性肾炎,急性膀胱炎12小时尿沉渣计数的相关症状是高血压,男子性功能障碍,蛋白尿,血尿,水肿,排尿困难及尿潴留,尿频伴尿急和尿痛”,
“spo_list”: [
[
“12小时尿沉渣计数”,
“相关疾病”,
“单纯型尿路感染”
]

其中text是原文,spo是三元组,s代表主体,p代表关系,s代表客体。


3.数据预处理

这里我都已经写好注视了,这个一个数据预处理逻辑

代码如下(示例):

class IterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, data, random):
        super(IterableDataset).__init__()
        self.data = data
        self.random = random
        self.tokenizer = config.tokenizer

    def __len__(self):
        return len(self.data)

    def search(self, sequence, pattern):
        n = len(pattern)
        for i in range(len(sequence)):
            if sequence[i:i + n] == pattern:
                return i
        return -1

    def process_data(self):
        idxs = list(range(len(self.data)))
        if self.random:
            np.random.shuffle(idxs)
        batch_size = config.batch_size
        max_seq_len = config.max_seq_len
        num_p = config.num_p
        batch_token_ids = np.zeros((batch_size, max_seq_len), dtype=np.int)
        batch_mask_ids = np.zeros((batch_size, max_seq_len), dtype=np.int)
        batch_segment_ids = np.zeros((batch_size, max_seq_len), dtype=np.int)
        batch_subject_ids = np.zeros((batch_size, 2), dtype=np.int)
        batch_subject_labels = np.zeros((batch_size, max_seq_len, 2), dtype=np.int)
        batch_object_labels = np.zeros((batch_size, max_seq_len, num_p, 2), dtype=np.int)
        batch_i = 0
        for i in idxs:
            text = self.data[i]['text']
            batch_token_ids[batch_i, :] = self.tokenizer.encode(text, max_length=max_seq_len, pad_to_max_length=True,
                                                                add_special_tokens=True)
            batch_mask_ids[batch_i, :len(text) + 2] = 1#对pad出来的设置成0
            spo_list = self.data[i]['spo_list']
            idx = np.random.randint(0, len(spo_list), size=1)[0]#相当于每次都是随机选一个S来组成数据
            s_rand = self.tokenizer.encode(spo_list[idx][0])[1:-1]#S的ID编码
            s_rand_idx = self.search(list(batch_token_ids[batch_i, :]), s_rand)#S所在text的开始索引位置
            batch_subject_ids[batch_i, :] = [s_rand_idx, s_rand_idx + len(s_rand) - 1]#S所在text的起始和终止索引位置
            for i in range(len(spo_list)):
                spo = spo_list[i]
                s = self.tokenizer.encode(spo[0])[1:-1]#不要首尾特殊字符
                p = config.prediction2id[spo[1]]
                o = self.tokenizer.encode(spo[2])[1:-1]
                s_idx = self.search(list(batch_token_ids[batch_i]), s)#S的开始位置
                o_idx = self.search(list(batch_token_ids[batch_i]), o)#O的开始位置
                if s_idx != -1 and o_idx != -1:#他俩都存在的话
                    batch_subject_labels[batch_i, s_idx, 0] = 1#到时候要预测每一个token是不是S的起始和终止位置
                    batch_subject_labels[batch_i, s_idx + len(s) - 1, 1] = 1
                    if s_idx == s_rand_idx:
                        batch_object_labels[batch_i, o_idx, p, 0] = 1#记录O的开始位置及S与O之间的关系
                        batch_object_labels[batch_i, o_idx + len(o) - 1, p, 1] = 1#记录O的结束位置及S与O之间的关系
            batch_i += 1
            if batch_i == batch_size or i == idxs[-1]:
                yield batch_token_ids, batch_mask_ids, batch_segment_ids, batch_subject_labels, batch_subject_ids, batch_object_labels
                batch_token_ids[:, :] = 0
                batch_mask_ids[:, :] = 0
                batch_subject_ids[:, :] = 0
                batch_subject_labels[:, :, :] = 0
                batch_object_labels[:, :, :, :] = 0
                batch_i = 0

    def get_stream(self):
        return cycle(self.process_data())

    def __iter__(self):
        return self.get_stream()

4.平平无奇的ner模型Model4s

这里和普通的ner任务完全一样,就是调用huggingface的bert预训练模型接口

代码如下(示例):

class Model4s(nn.Module):
    def __init__(self, hidden_size=768):
        super(Model4s, self).__init__()
        self.bert = BertModel.from_pretrained(config.PATH_BERT)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.dropout = nn.Dropout(p=0.2)
        self.linear = nn.Linear(in_features=hidden_size, out_features=2, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, input_mask, segment_ids, hidden_size=768):
        hidden_states = self.bert(input_ids,
                                  attention_mask=input_mask,
                                  token_type_ids=segment_ids)[0]  # (batch_size, sequence_length, hidden_size)
        output = self.sigmoid(self.linear(self.dropout(hidden_states))).pow(2)

        return output, hidden_states

5.很有创意的Model4po

这里通过上面的Model4s找到text主题后,固定主题,然后找到相应的客体和关系,相当精彩的想法。也是整个项目精彩之处。

代码如下(示例):

class Model4po(nn.Module):
    def __init__(self, num_p=config.num_p, hidden_size=768):
        super(Model4po, self).__init__()
        self.dropout = nn.Dropout(p=0.4)
        self.linear = nn.Linear(in_features=hidden_size, out_features=num_p * 2, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, hidden_states, batch_subject_ids, input_mask):
        all_s = torch.zeros((hidden_states.shape[0], hidden_states.shape[1], hidden_states.shape[2]),
                            dtype=torch.float32)

        for b in range(hidden_states.shape[0]):
            s_start = batch_subject_ids[b][0]
            s_end = batch_subject_ids[b][1]
            s = hidden_states[b][s_start] + hidden_states[b][s_end]#起始特征+终止特征
            cue_len = torch.sum(input_mask[b])#实际长度
            all_s[b, :cue_len, :] = s#将所有位置的特征设置成主体的
        hidden_states += all_s#每一个位置实际特征都是 自身 + 主体
        #我估计pow(4)这么大个数 是由于预测出来的结果都有些大,要降低预测值大小
        output = self.sigmoid(self.linear(self.dropout(hidden_states))).pow(4)#预测每一个位置与主题的关系

        return output  # (batch_size, max_seq_len, num_p*2)

6.正常训练模型

写的也比较普通,把2个模型的损失加起来就行。

def train(train_data_loader, model4s, model4po, optimizer):
    for epoch in range(config.EPOCH):
        begin_time = time.time()
        model4s.train()
        model4po.train()
        train_loss = 0.
        for bi, batch in enumerate(train_data_loader):
            if bi >= len(train_data_loader) // config.batch_size:
                break
            batch_token_ids, batch_mask_ids, batch_segment_ids, batch_subject_labels, batch_subject_ids, batch_object_labels = batch
            batch_token_ids = torch.tensor(batch_token_ids, dtype=torch.long)
            batch_mask_ids = torch.tensor(batch_mask_ids, dtype=torch.long)
            batch_segment_ids = torch.tensor(batch_segment_ids, dtype=torch.long)
            batch_subject_labels = torch.tensor(batch_subject_labels, dtype=torch.float)
            batch_object_labels = torch.tensor(batch_object_labels, dtype=torch.float).view(config.batch_size,
                                                                                            config.max_seq_len,
                                                                                            config.num_p * 2)
            batch_subject_ids = torch.tensor(batch_subject_ids, dtype=torch.int)

            batch_subject_labels_pred, hidden_states = model4s(batch_token_ids, batch_mask_ids, batch_segment_ids)
            loss4s = loss_fn(batch_subject_labels_pred, batch_subject_labels.to(torch.float32))
            loss4s = torch.mean(loss4s, dim=2, keepdim=False) * batch_mask_ids#只计算非pad部分
            loss4s = torch.sum(loss4s)
            loss4s = loss4s / torch.sum(batch_mask_ids)

            batch_object_labels_pred = model4po(hidden_states, batch_subject_ids, batch_mask_ids)
            loss4po = loss_fn(batch_object_labels_pred, batch_object_labels.to(torch.float32))
            loss4po = torch.mean(loss4po, dim=2, keepdim=False) * batch_mask_ids
            loss4po = torch.sum(loss4po)
            loss4po = loss4po / torch.sum(batch_mask_ids)

            loss = loss4s + loss4po
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += float(loss.item())
            print('batch:', bi, 'loss:', float(loss.item()))

        print('final train_loss:', train_loss / len(train_data_loader) * config.batch_size, 'cost time:',
              time.time() - begin_time)

    del train_data_loader
    gc.collect();

    return {
        "model4s_state_dict": model4s.state_dict(),
        "model4po_state_dict": model4po.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }

7.demo案例

输出的结果就是文章开头

import medical_re
import json
model4s, model4po = medical_re.load_model()

text = '据报道称,新冠肺炎患者经常会发热、咳嗽,少部分患者会胸闷、乏力,其病因包括: 1.自身免疫系统缺陷\n2.人传人。'  # content是输入的一段文字
res = medical_re.get_triples(text, model4s, model4po)
print(json.dumps(res, ensure_ascii=False, indent=True))

总结

通过huggingface中bert预训练模型实现的一个非常不错的知识图谱搭建过程。Neo4j怎么搭建,以后有时间继续更新。

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值