BERT Event Entity Extraction (with Predict)

Adds the predict code. Accuracy can still be improved, and the case where BERT's tokenization makes it impossible to align a prediction back to the original sentence is still unsolved.
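As a side illustration of that limitation (not part of the pipeline below): WordPiece tokenization lower-cases the text and can emit '##' sub-pieces or '[UNK]' tokens, so simply joining the predicted tokens does not always reproduce a substring of the original sentence. A minimal sketch, assuming the same chinese_L-12_H-768_A-12 vocabulary used below; the example sentence is made up:

# Illustration only: why joined tokens may not map back to the raw text.
from bert import tokenization

tok = tokenization.FullTokenizer(
    vocab_file='./bert/chinese_L-12_H-768_A-12/vocab.txt', do_lower_case=True)

text = '公司Q3财报显示EBITDA下滑'     # made-up example sentence
tokens = tok.tokenize(text)
print(tokens)                          # pieces such as 'q3', '##da' or '[UNK]' may appear
print(''.join(tokens) == text)         # False: lower-casing, '##' and '[UNK]' break the alignment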

Maximum length is 512; the data has already been filtered to stay within it.
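For context, a hedged sketch of the kind of pre-filtering implied here (the input file name and record layout are taken from the training code below; the filtered output file name is made up): keep only records whose concatenated '__type__text' input fits within 510 characters, leaving room for [CLS] and [SEP].

# Hedged sketch of a possible length filter, not the author's actual script.
MAX_CHARS = 510  # 512 minus [CLS] and [SEP]

kept = []
with open('data/train_data.txt', encoding='UTF-8') as fp:
    for line in fp:
        item = eval(line.strip())      # each line is a Python literal: (text, event_type, entity)
        text, event_type = item[0], item[1]
        if len('__%s__%s' % (event_type, text)) <= MAX_CHARS:
            kept.append(line)

with open('data/train_data_filtered.txt', mode='w', encoding='UTF-8') as fp:  # hypothetical output file
    fp.writelines(kept)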

import tensorflow as tf
import numpy as np
from bert import modeling
from bert import tokenization
from bert import optimization
import os


flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_integer('train_batch_size',6,'define the train batch size')
flags.DEFINE_integer('num_train_epochs',3,'define the num train epochs')
flags.DEFINE_float('warmup_proportion',0.1,'define the warmup proportion')
flags.DEFINE_float('learning_rate',5e-5,'the initial learning rate for adam')
flags.DEFINE_bool('is_training',True,'define whether to fine-tune the bert model')
flags.DEFINE_integer('max_sentence_len',512,'define the max len of sentence')
flags.DEFINE_bool('task_train',True,'define the train task')
flags.DEFINE_bool('task_predict',True,'define the predict task')


def get_start_end_index(text,subtext):
    for i in range(len(text)):
        if text[i:i+len(subtext)] == subtext:
            return (i,i+len(subtext)-1)
    return (-1,-1)


train_data = []
with open('data/train_data.txt',encoding='UTF-8') as fp:
    strLines = fp.readlines()
    strLines = [item.strip() for item in strLines]
    strLines = [eval(item) for item in strLines]
    train_data.extend(strLines)

test_data = []
with open('data/test_data.txt',encoding='UTF-8') as fp:
    strLines = fp.readlines()
    strLines = [item.strip() for item in strLines]
    strLines = [eval(item) for item in strLines]
    test_data.extend(strLines)



# config_path = r'D:\NLP_SOUNDAI\learnTensor\package9\bert\chinese_L-12_H-768_A-12\bert_config.json'
# checkpoint_path = r'D:\NLP_SOUNDAI\learnTensor\package9\bert\chinese_L-12_H-768_A-12\bert_model.ckpt'
# dict_path = r'D:\NLP_SOUNDAI\learnTensor\package9\bert\chinese_L-12_H-768_A-12\vocab.txt'
config_path = './bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = './bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = './bert/chinese_L-12_H-768_A-12/vocab.txt'
bert_config = modeling.BertConfig.from_json_file(config_path)
tokenizer = tokenization.FullTokenizer(vocab_file=dict_path,do_lower_case=True)




def input_str_concat(inputList):
    assert len(inputList) == 2
    t, c = inputList
    newStr = '__%s__%s' % (c, t)
    newStr = newStr[:510]
    tokens = tokenizer.tokenize(newStr)
    tokens = ['[CLS]'] + tokens + ['[SEP]']
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)
    segment_ids = [0] * len(input_ids)
    return tokens, (input_ids, input_mask, segment_ids)



def sequence_padding(sequence):
    lenlist = [len(item) for item in sequence]
    maxlen = max(lenlist)
    return np.array([
        np.concatenate([item,[0]*(maxlen - len(item))]) if len(item) < maxlen else item for item in sequence
    ])


# get train data batch
def get_data_batch():
    batch_size = FLAGS.train_batch_size
    epoch = FLAGS.num_train_epochs
    for oneEpoch in range(epoch):
        num_batches = ((len(train_data) -1) // batch_size) + 1
        for i in range(num_batches):
            batch_data = train_data[i*batch_size:(i+1)*batch_size]
            yield_batch_data = {
                'input_ids':[],
                'input_mask':[],
                'segment_ids':[],
                'start_ids':[],
                'end_ids':[]
            }
            for item in batch_data:
                tokens, (input_ids, input_mask, segment_ids) = input_str_concat(item[:-1])
                target_tokens = tokenizer.tokenize(item[2])

                start, end = get_start_end_index(tokens, target_tokens)
                # skip samples whose target span cannot be located in the tokenized text
                if start == -1:
                    continue

                start_ids = [0] * len(input_ids)
                end_ids = [0] * len(input_ids)
                start_ids[start] = 1
                end_ids[end] = 1
                yield_batch_data['input_ids'].append(input_ids)
                yield_batch_data['input_mask'].append(input_mask)
                yield_batch_data['segment_ids'].append(segment_ids)
                yield_batch_data['start_ids'].append(start_ids)
                yield_batch_data['end_ids'].append(end_ids)
            yield_batch_data['input_ids'] = sequence_padding(yield_batch_data['input_ids'])
            yield_batch_data['input_mask'] = sequence_padding(yield_batch_data['input_mask'])
            yield_batch_data['segment_ids'] = sequence_padding(yield_batch_data['segment_ids'])
            yield_batch_data['start_ids'] = sequence_padding(yield_batch_data['start_ids'])
            yield_batch_data['end_ids'] = sequence_padding(yield_batch_data['end_ids'])
            yield yield_batch_data


with tf.Graph().as_default(),tf.Session() as sess:
    input_ids_p = tf.placeholder(dtype=tf.int64,shape=[None,None],name='input_ids_p')
    input_mask_p = tf.placeholder(dtype=tf.int64,shape=[None,None],name='input_mask_p')
    segment_ids_p = tf.placeholder(dtype=tf.int64,shape=[None,None],name='segment_ids_p')
    start_p = tf.placeholder(dtype=tf.int64,shape=[None,None],name='start_p')
    end_p = tf.placeholder(dtype=tf.int64,shape=[None,None],name='end_p')


    model = modeling.BertModel(config=bert_config,
                               is_training=FLAGS.is_training,
                               input_ids=input_ids_p,
                               input_mask=input_mask_p,
                               token_type_ids=segment_ids_p,
                               use_one_hot_embeddings=False)
    output_layer = model.get_sequence_output()

   
    word_dim = output_layer.get_shape().as_list()[-1]
    output_reshape = tf.reshape(output_layer,shape=[-1,word_dim],name='output_reshape')

    with tf.variable_scope('weight_and_bias',reuse=tf.AUTO_REUSE,initializer=tf.truncated_normal_initializer(mean=0.,stddev=0.05)):
        weight_start = tf.get_variable(name='weight_start',shape=[word_dim,1])
        bias_start = tf.get_variable(name='bias_start',shape=[1])
        weight_end = tf.get_variable(name='weight_end',shape=[word_dim,1])
        bias_end = tf.get_variable(name='bias_end',shape=[1])

    with tf.name_scope('predict_start_and_end'):
        pred_start = tf.einsum('ijk,kd->ijd', output_layer, weight_start)
        pred_start = tf.nn.bias_add(pred_start, bias_start)
        pred_start = tf.squeeze(pred_start, -1)

        pred_end = tf.einsum('ijk,kd->ijd', output_layer, weight_end)
        pred_end = tf.nn.bias_add(pred_end, bias_end)
        pred_end = tf.squeeze(pred_end, -1)

        pred_start_index = tf.argmax(pred_start,axis=1)
        pred_end_index = tf.argmax(pred_end,axis=1)

    with tf.name_scope('loss'):
        start_labels = tf.cast(start_p, tf.float32)
        end_labels = tf.cast(end_p, tf.float32)
        loss1 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred_start, labels=start_labels))

        # mask the end logits so that positions before the gold start become impossible:
        # cumsum over the one-hot start labels is 0 before the start index and 1 from it
        # onwards, so (1 - cumsum) * 1e10 pushes the logits of invalid end positions to -inf.
        # Note: the end_p placeholder itself must not be overwritten, it is fed in the training loop.
        end_mask = (1.0 - tf.cumsum(start_labels, axis=1)) * 1e10
        masked_pred_end = pred_end - end_mask

        loss2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=masked_pred_end, labels=end_labels))
        loss = loss1 + loss2

    with tf.name_scope('acc_predict'):
        start_acc_bool = tf.equal(tf.argmax(start_p, axis=1), tf.argmax(pred_start, axis=1))
        end_acc_bool = tf.equal(tf.argmax(end_p, axis=1), tf.argmax(pred_end, axis=1))
        start_acc = tf.reduce_mean(tf.cast(start_acc_bool, dtype=tf.float32))
        end_acc = tf.reduce_mean(tf.cast(end_acc_bool, dtype=tf.float32))
        total_acc = tf.reduce_mean(tf.cast(tf.reduce_all([start_acc_bool, end_acc_bool], axis=0), dtype=tf.float32))

    with tf.name_scope('train_op'):
        num_train_steps = int(
            len(train_data) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
        train_op = optimization.create_optimizer(
                loss, FLAGS.learning_rate, num_train_steps, num_warmup_steps, use_tpu=False)



    tvars = tf.trainable_variables()
    (assignment_map,initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,checkpoint_path)
    tf.train.init_from_checkpoint(checkpoint_path,assignment_map)
    sess.run(tf.variables_initializer(tf.global_variables()))

    if FLAGS.task_train:
        total_steps = 0
        for yield_batch_data in get_data_batch():
            total_steps += 1
            feed_dict = {
                input_ids_p: yield_batch_data['input_ids'],
                input_mask_p: yield_batch_data['input_mask'],
                segment_ids_p: yield_batch_data['segment_ids'],
                start_p: yield_batch_data['start_ids'],
                end_p: yield_batch_data['end_ids']
            }
            fetches = [train_op, loss, start_acc, end_acc, total_acc]

            _, loss_val, start_acc_val, end_acc_val, total_acc_val = sess.run(fetches, feed_dict=feed_dict)
            print('i : %s, loss : %s, start_acc : %s, end_acc : %s, total_acc : %s' % (
                total_steps, loss_val, start_acc_val, end_acc_val, total_acc_val))
        print('train task done ...')

    if FLAGS.task_predict:
        resultList = []
        for item in test_data:
            if item[1] == '其他':
                resultList.append('NaN')
            else:
                yield_batch_data = {
                    'input_ids': [],
                    'input_mask': [],
                    'segment_ids': [],
                }

                tokens, (input_ids, input_mask, segment_ids) = input_str_concat(item)

                yield_batch_data['input_ids'].append(input_ids)
                yield_batch_data['input_mask'].append(input_mask)
                yield_batch_data['segment_ids'].append(segment_ids)

                yield_batch_data['input_ids'] = sequence_padding(yield_batch_data['input_ids'])
                yield_batch_data['input_mask'] = sequence_padding(yield_batch_data['input_mask'])
                yield_batch_data['segment_ids'] = sequence_padding(yield_batch_data['segment_ids'])

                feed_dict = {
                    input_ids_p: yield_batch_data['input_ids'],
                    input_mask_p: yield_batch_data['input_mask'],
                    segment_ids_p: yield_batch_data['segment_ids']
                }

                fetches = [pred_start_index, pred_end_index]

                start_index, end_index = sess.run(fetches, feed_dict=feed_dict)
                start = start_index[0]
                end = end_index[0]
                # join the predicted token span back into a string
                oneResult = ''.join(tokens[start:end + 1])

                if oneResult in item[0]:
                    resultList.append(oneResult)
                elif oneResult.upper() in item[0]:
                    resultList.append(oneResult.upper())
                else:
                    # the tokenized span could not be matched back to the original
                    # text, so fall back to a cleaned-up, upper-cased guess
                    oneResult = oneResult.upper()
                    oneResult = oneResult.replace('[UNK]', '').replace('#', '').strip()
                    resultList.append(oneResult)

        with open('result.txt',encoding='UTF-8',mode='a') as ff:
            ff.write('\n'.join(resultList)+'\n')





 
