KG_BERT_model

from transformers import BertModel, BertPreTrainedModel
import torch.nn as nn
import torch
class SubjectModel(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        # two outputs per token: probability of being a subject start / subject end
        self.dense = nn.Linear(config.hidden_size, 2)
    
    def forward(self, input_ids, attention_mask=None):
        output = self.bert(input_ids, attention_mask=attention_mask)
        subject_out = self.dense(output[0])          # (batch, seq_len, 2)
        subject_out = torch.sigmoid(subject_out)
        # also return the BERT sequence output so ObjectModel can reuse it
        return output[0], subject_out
    

class ObjectModel(nn.Module):
    def __init__(self, subject_model):
        super().__init__()
        self.encoder = subject_model
        # project the (start, end) subject position pair into the hidden space
        self.dense_subject_position = nn.Linear(2, 768)
        # 49 predicate types, each with a start and an end score per token
        self.dense_object = nn.Linear(768, 49 * 2)
    
    def forward(self, input_ids, subject_positions, attention_mask=None):
        output, subject_out = self.encoder(input_ids, attention_mask=attention_mask)
        subject_positions = self.dense_subject_position(subject_positions).unsqueeze(1)
        object_out = output + subject_positions
        object_out = self.dense_object(object_out)
        object_out = torch.reshape(object_out, (object_out.shape[0], object_out.shape[1], 49, 2))
        object_out = torch.sigmoid(object_out)
        object_out = torch.pow(object_out, 4)
        return subject_out, object_out
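
A quick shape check shows how the two heads fit together. This is a minimal sketch, assuming a local './bert' checkpoint (as used further below) and the 49 predicate types hard-coded in dense_object; the dummy inputs are made up.

# Hedged sanity check: dummy batch of 2 sequences of length 16.
subject_model = SubjectModel.from_pretrained('./bert')
object_model = ObjectModel(subject_model)
dummy_ids = torch.randint(0, subject_model.config.vocab_size, (2, 16))
dummy_mask = torch.ones_like(dummy_ids)
dummy_subject_ids = torch.tensor([[3., 5.], [1., 2.]])   # (start, end) index pair per sample
subject_out, object_out = object_model(dummy_ids, dummy_subject_ids, dummy_mask)
print(subject_out.shape)   # torch.Size([2, 16, 2])    -> start/end score per token
print(object_out.shape)    # torch.Size([2, 16, 49, 2]) -> start/end score per token per predicate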


import json
import os
import numpy as np
import torch
from tqdm import tqdm
from transformers import AdamW, BertTokenizerFast

'''
1. GPU detection
2. Load the vocab
3. Load the tokenizer

4. Load the schema
5. load_data(train_data, valid_data)
6. Locate subject/object positions
7. sequence_padding (complements the attention_mask)
8. data_generator(batch_size, input_ids, attention_mask,
                  subject_labels, subject_ids, object_labels)
9. SPO class (used to compare y_predict with y_truth and compute precision/recall)

10. Load the pretrained model (from_pretrained / local load)
    subject_model --> object_model --> output
11. Optimizer: AdamW
12. loss_func = subject_loss (BCELoss) + object_loss (BCELoss) --> model.train()
13. Training:
    per batch
        --> optim.zero_grad()
        --> data_generator
        --> model
        --> loss
        --> backward()
        --> optim.step()
        --> step % 1000 == 0
            --> save_model
            --> with torch.no_grad()
            --> evaluate (y_predict, f1, precision, recall)

Notes:
    1. The top layer uses sigmoid, not softmax.
    2. Subject and object each have two output dimensions: start and end.
    3. Network structure: the BERT output splits into two branches that merge back into one;
       jointly trained as a multi-label classification task.
    4. Joint entity and relation extraction: the predicted subject is used to predict the
       object and the relation at the same time.
    5. Thresholds: subject_start=0.6, subject_end=0.5; object_start=0.2, object_end=0.2,
       applied after pow(object_output, 4) (see the small example after this block).
'''
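
A quick numeric illustration of note 5: raising the sigmoid output to the 4th power pushes low scores toward zero, so thresholding the powered score at 0.2 is roughly the same as thresholding the raw sigmoid output at 0.2 ** 0.25 ≈ 0.669. The numbers below are made up.

probs = [0.3, 0.5, 0.67, 0.9]              # hypothetical raw sigmoid outputs
print([round(p ** 4, 4) for p in probs])   # [0.0081, 0.0625, 0.2015, 0.6561]
print([(p ** 4) > 0.2 for p in probs])     # [False, False, True, True]
print(round(0.2 ** 0.25, 4))               # 0.6687 -> effective threshold on the raw sigmoid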


GPU_NUM = 0
device = torch.device(f'cuda:{GPU_NUM}') if torch.cuda.is_available() else torch.device('cpu')

# Map token id -> token string (used later to turn predicted index spans back into text).
vocab = {}
with open('bert/vocab.txt', encoding='utf-8') as f:
    for l in f.readlines():
        vocab[len(vocab)] = l.strip()

def load_data(filename):
    with open(filename, encoding='utf-8') as f:
        json_list = json.load(f)
    return json_list

train_data = load_data('data/train.json')
valid_data = load_data('data/dev.json')
type(train_data), len(train_data), len(valid_data)
train_data[0]
tokenizer = BertTokenizerFast.from_pretrained('bert')
with open('data/schemas.json', encoding='utf-8') as f:
    json_list = json.load(f)
    id2predicate = json_list[0]
    predicate2id = json_list[1]
len(id2predicate), type(id2predicate)
id2predicate, predicate2id
def search(pattern, sequence):
    """Return the index where `pattern` first appears as a sub-list of `sequence`, or -1."""
    n = len(pattern)
    for i in range(len(sequence)):
        if sequence[i: i + n] == pattern:
            return i
    return -1
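
A small usage example for search, with made-up token-id lists:

print(search([3, 4], [1, 2, 3, 4, 5]))   # 2   -> pattern starts at index 2
print(search([9, 9], [1, 2, 3, 4, 5]))   # -1  -> pattern not found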


def sequence_padding(inputs, length=None, padding=0, mode='post'):
    if length is None:
        length = max([len(x) for x in inputs])
    
    pad_width = [(0, 0) for _ in np.shape(inputs[0])]
    outputs = []
    
    for x in inputs:
        x = x[:length]
        if mode == 'post':
            pad_width[0] = (0, length - len(x))
        elif mode == 'pre':
            pad_width[0] = (length - len(x), 0)
        else:
            raise ValueError('"mode" argument must be "post" or "pre".')
        x = np.pad(x, pad_width, 'constant', constant_values=padding)
        outputs.append(x)
    return np.array(outputs)
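
A minimal example of sequence_padding with made-up id lists; only the first axis is padded, so the same function also works for the 2-D subject labels and 3-D object labels below:

print(sequence_padding([[1, 2, 3], [4, 5]]))
# [[1 2 3]
#  [4 5 0]]
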
def data_generator(data, batch_size=3):
    batch_input_ids, batch_attention_mask = [], []
    batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], []
    texts = []
    for i, d in enumerate(data):
        text = d['text']
        texts.append(text)
        encoding = tokenizer(text=text)
        
        input_ids, attention_mask = encoding.input_ids, encoding.attention_mask
        spoes = {}
        for s, p, o in d['spo_list']:
            s_encoding = tokenizer(text=s).input_ids[1: -1]
            o_encoding = tokenizer(text=o).input_ids[1:-1]
            
            s_idx = search(s_encoding, input_ids)
            o_idx = search(o_encoding, input_ids)
            
            p = predicate2id[p]
            
            if s_idx != -1 and o_idx != -1:
                s = (s_idx, s_idx + len(s_encoding) - 1)
                # keep the predicate id with the object span: (start, end, predicate)
                o = (o_idx, o_idx + len(o_encoding) - 1, p)
                if s not in spoes:
                    spoes[s] = []
                spoes[s].append(o)
        if spoes:
            # subject_labels
            subject_labels = np.zeros((len(input_ids), 2))
            for s in spoes:
                subject_labels[s[0], 0] = 1
                subject_labels[s[1], 1] = 1
            
            start, end = np.array(list(spoes.keys())).T
            start = np.random.choice(start)
            end = end[end>=start][0]
            subject_ids = (start,end)
            
            # object 
            object_labels = np.zeros((len(input_ids), len(predicate2id),2))
            for o in spoes.get(subject_ids, []):
                # object_labels[token_position, predicate_id, start(0)/end(1)]
                object_labels[o[0], o[2], 0] = 1
                object_labels[o[1], o[2], 1] = 1
            
            batch_input_ids.append(input_ids)
            batch_attention_mask.append(attention_mask)
            batch_subject_labels.append(subject_labels)
            batch_subject_ids.append(subject_ids)
            batch_object_labels.append(object_labels)
            
            if len(batch_subject_labels) == batch_size or i == len(data) - 1:
                batch_input_ids = sequence_padding(batch_input_ids)
                batch_attention_mask = sequence_padding(batch_attention_mask)
                batch_subject_labels = sequence_padding(batch_subject_labels)
                batch_subject_ids = np.array(batch_subject_ids)
                batch_object_labels = sequence_padding(batch_object_labels)
                
                yield [torch.from_numpy(batch_input_ids).long(),
                       torch.from_numpy(batch_attention_mask).long(),
                       torch.from_numpy(batch_subject_labels),
                       torch.from_numpy(batch_subject_ids),
                       torch.from_numpy(batch_object_labels)
                      ]
                batch_input_ids, batch_attention_mask = [], []
                batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], []
                
            
train_loader = data_generator(train_data, batch_size=8)
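An optional sanity check on batch shapes (it consumes one batch from train_loader; train_func below builds its own generator each epoch). The shapes assume len(predicate2id) == 49, matching the model head:

batch = next(train_loader)
names = ['input_ids', 'attention_mask', 'subject_labels', 'subject_ids', 'object_labels']
for name, t in zip(names, batch):
    print(name, tuple(t.shape), t.dtype)
# input_ids / attention_mask: (8, seq_len); subject_labels: (8, seq_len, 2)
# subject_ids: (8, 2); object_labels: (8, seq_len, 49, 2)
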
class SPO(tuple):
    def __init__(self, spo):
        self.spox = (spo[0], spo[1], spo[2])
    
    def __hash__(self):
        return self.spox.__hash__()
    
    def __eq__(self, spo):
        return self.spox == spo.spox
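SPO makes the predicted and gold triples hashable, so precision and recall reduce to plain set operations. A tiny example with placeholder triples:

pred = {SPO(('s1', 'p1', 'o1')), SPO(('s2', 'p2', 'o2'))}
gold = {SPO(('s1', 'p1', 'o1'))}
correct = len(pred & gold)                                     # 1
precision, recall = correct / len(pred), correct / len(gold)   # 0.5, 1.0
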
if os.path.exists('graph_model.bin'):
    print('load model')
    model = torch.load('graph_model.bin').to(device)
    subject_model = model.encoder
else:
    subject_model = SubjectModel.from_pretrained('./bert')
    subject_model.to(device)
    
    model = ObjectModel(subject_model)
    model.to(device)
optim = AdamW(model.parameters(), lr=5e-5)
loss_func = torch.nn.BCELoss()

def train_func():
    train_loss = 0
    # Re-create the generator each epoch: a Python generator is exhausted after one pass.
    pbar = tqdm(data_generator(train_data, batch_size=8))
#     torch.cuda.empty_cache()
    for step, batch in enumerate(pbar):
       
        optim.zero_grad()
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        
        subject_labels = batch[2].to(device)
        subject_ids = batch[3].to(device)
        object_labels = batch[4].to(device)
        
        subject_out, object_out = model(input_ids, subject_ids.float(), attention_mask)
        subject_out = subject_out * attention_mask.unsqueeze(-1)
        # object_out==', torch.Size([8, 106, 49, 2]) 'attention_mask==', torch.Size([8, 106])
        object_out = object_out * attention_mask.unsqueeze(-1).unsqueeze(-1)
        
        subject_loss = loss_func(subject_out, subject_labels.float())
        object_loss = loss_func(object_out, object_labels.float())
        
        
        loss = subject_loss + object_loss
        train_loss += loss.item()
        loss.backward()
        optim.step()
        
        pbar.update()
        pbar.set_description(f'train loss: {loss.item()}')
        
        if step % 100 == 0 and step != 0:
            torch.save(model, 'graph_model.bin')
            with torch.no_grad():
                X, Y, Z = 1e-10, 1e-10, 1e-10
                # Use a separate progress bar for the quick validation pass.
                val_pbar = tqdm()
                
                for data in valid_data[0: 100]:
                    spo = []
                    text = data['text']
                    spo_ori = data['spo_list']
                    en = tokenizer(text=text, return_tensors='pt')
                    _, subject_preds = model.encoder(en.input_ids.to(device),
                                                    en.attention_mask.to(device))
                    
                    subject_preds = subject_preds.cpu().data.numpy()
                    start = np.where(subject_preds[0, :, 0] > 0.6)[0]
                    end  = np.where(subject_preds[0, :, 1] > 0.5)[0]
                    
                    
                    subjects = []
                    for i in start:
                        j = end[end >= i]
                        if len(j) > 0:
                            j = j[0]
                            subjects.append((i,j))
                    if subjects:
                        for s in subjects:
                            index = en.input_ids.cpu().data.numpy().squeeze(0)[s[0]:s[1] + 1]
                            subject = ''.join([vocab[i] for i in index])
                            _, object_preds = model(en.input_ids.to(device),
                                                    torch.from_numpy(np.array([s])).float().to(device),
                                                    en.attention_mask.to(device))
                            object_preds = object_preds.cpu().data.numpy()
                            for object_pred in object_preds:
                                start = np.where(object_pred[:, :, 0] > 0.2)
                                end = np.where(object_pred[:, :, 1] > 0.2)
                                for _start, predicate1 in zip(* start):
                                    for _end, predicate2 in zip(* end):
                                        if _start <= _end and predicate1 == predicate2:
                                            index = en.input_ids.cpu().data.numpy().squeeze(0)[_start: _end + 1]
                                            object = ''.join([vocab[i] for i in index])
                                            predicate = id2predicate[str(predicate1)]
                                            spo.append([subject, predicate, object])
                     
                    if spo:
                        print(spo[-1])
                    Predicts = set([SPO(_spo) for _spo in spo])
                    Truth = set([SPO(_spo) for _spo in spo_ori])
                    X += len(Predicts & Truth)
                    Y += len(Predicts)
                    Z += len(Truth)
                    
                    f1, precision, recall = 2 * X / (Y + Z), X/Y, X/Z
                    val_pbar.update()
                    val_pbar.set_description(f'f1: {f1:.5f}, precision: {precision:.5f}, recall: {recall:.5f}')
                val_pbar.close()
                print('f1:', f1, 'precision: ', precision, 'recall: ', recall)
                
for epoch in range(100):
    train_func()