信息抽取(三)三元关系抽取——改良后的层叠式指针网络,让我的模型F1提升近4%(接上篇)

信息抽取(三)三元关系抽取——改良后的层叠式指针网络前言优化在验证集上的模型推理结果的SPO抽取方法不随机选择S(subject),⽽是遍历所有不同主语的标注样本构建训练集。模型优化加入对抗训练FGM总结前言基于我上一篇的博客:信息抽取(二)花了一个星期走了无数条弯路终于用TF复现了苏神的《Bert三元关系抽取模型》,我到底悟到了什么?复现后的模型在百度2019年语言竞赛三元关系抽取的数据集上F1值仅达到77%,我在博文总结了几点可以优化的方向,并实现一系列层叠式指针网络的改良。在此贴出代码和提升结
摘要由CSDN通过智能技术生成


前言

基于我上一篇的博客:信息抽取(二)花了一个星期走了无数条弯路终于用TF复现了苏神的《Bert三元关系抽取模型》,我到底悟到了什么?

复现后的模型在百度2019年语言竞赛三元关系抽取的数据集上F1值仅达到77%,我在博文总结了几点可以优化的方向,并实现一系列层叠式指针网络的改良。在此贴出代码和提升结果。


优化在验证集上的模型推理结果的SPO抽取方法

原方案是将token decode后组成结果,改良后的方案通过在token上的index返回到原文本中切割出答案,这避免了token无法识别一些特殊文字和符号亦或是空格。

def rematch_text_word(tokenizer,text,enc_context,enc_start,enc_end):
    span = [a.span()[0] for a in re.finditer(' ', text)]
    decode_list = [tokenizer.decode([i]) for i in enc_context][1:]
    start = 0
    end = 0
    len_start = 0
    for i in range(len(decode_list)):
        if i ==  enc_start - 1:
            start = len_start
        j = decode_list[i]
        if '#' in j and len(j)>1:
            j = j.replace('#','')
        if j == '[UNK]':
            j = '。'
        len_start += len(j)
        if i == enc_end - 1:
            end = len_start
            break
    for span_index in span:
        if start >= span_index:
            start += 1
            end += 1
        if end > span_index and span_index>start:
            end += 1
    return text[start:end]

不随机选择S(subject),⽽是遍历所有不同主语的标注样本构建训练集。

原方案是对于每组文本数据,仅随机抽取一个S以及其相关的PO构建成一组数据。
改良后,对于每组文本数据,分别抽取其所有不同的S以及其相关的PO组成多组数据。

尽管对不不同样本来说S是相同的,但在实验中发现,模型对于S的推理往往比PO关系优秀太多,因此S的可能过拟合来提升模型在PO上的表现是值得的。

以上两个优化方案的提升效果:F1:0.7719 —> 0.7979

def proceed_data(text_list,spo_list,p2id,id2p,tokenizer,MAX_LEN,sop_count):
    id_label = {
   }
    ct = len(text_list)
    MAX_LEN = MAX_LEN
    print(sop_count)
    input_ids = np.zeros((sop_count,MAX_LEN),dtype='int32')
    attention_mask = np.zeros((sop_count,MAX_LEN),dtype='int32')
    start_tokens = np.zeros((sop_count,MAX_LEN),dtype='int32')
    end_tokens = np.zeros((sop_count,MAX_LEN),dtype='int32')
    send_s_po = np.zeros((sop_count,2),dtype='int32')
    object_start_tokens = np.zeros((sop_count,MAX_LEN,len(p2id)),dtype='int32')
    object_end_tokens = np.zeros((sop_count,MAX_LEN,len(p2id)),dtype='int32')
    index_vaild = -1
    for k in range(ct):
        context_k = text_list[k].lower().replace(' ','')
        enc_context = tokenizer.encode(context_k,max_length
  • 1
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值