Studying the Keras Code of CasRel

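This post walks through a simple and efficient relational triple extraction model implemented in Keras. The model builds on BERT and uses pointer-network-style tagging: it first extracts the subject and then the objects. The code shows how the subject model and the object model are built. The main open questions are the shape of sub_head and the purpose of assigning the input layers to a second set of names.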

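Paper:
A Novel Cascade Binary Tagging Framework for Relational Triple Extraction
The paper proposes a method for relational triple extraction. The method is simple and clear, and the authors' source code is written in Keras, so it is worth studying.
The model architecture is shown below:
[Figure: model architecture]
The main model code follows: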

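import keras.backend as K
from keras.layers import Add, Average, Dense, Input, Lambda
from keras.models import Model
from keras.optimizers import Adam
from keras_bert import load_trained_model_from_checkpoint
# seq_gather (used below) is a helper defined elsewhere in the authors' code; it is not shown in this post.

def E2EModel(bert_config_path, bert_checkpoint_path, LR, num_rels):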
    bert_model = load_trained_model_from_checkpoint(bert_config_path, bert_checkpoint_path, seq_len=None)
    for l in bert_model.layers:
        l.trainable = True

    tokens_in = Input(shape=(None,))
    segments_in = Input(shape=(None,))
    gold_sub_heads_in = Input(shape=(None,))
    gold_sub_tails_in = Input(shape=(None,))
    sub_head_in = Input(shape=(1,))  # why this shape?
    sub_tail_in = Input(shape=(1,))
    gold_obj_heads_in = Input(shape=(None, num_rels))
    gold_obj_tails_in = Input(shape=(None, num_rels))

    # Note: every input is assigned to a second name here for later use. What is the point?
    tokens, segments, gold_sub_heads, gold_sub_tails, sub_head, sub_tail, gold_obj_heads, gold_obj_tails = tokens_in, segments_in, gold_sub_heads_in, gold_sub_tails_in, sub_head_in, sub_tail_in, gold_obj_heads_in, gold_obj_tails_in
    mask = Lambda(lambda x: K.cast(K.greater(K.expand_dims(x, 2), 0), 'float32'))(tokens)
    # The mask is built by a Lambda layer (token id > 0), so no separate mask input is needed

    tokens_feature = bert_model([tokens, segments])
    pred_sub_heads = Dense(1, activation='sigmoid')(tokens_feature)
    pred_sub_tails = Dense(1, activation='sigmoid')(tokens_feature)

    subject_model = Model([tokens_in, segments_in], [pred_sub_heads, pred_sub_tails])  # step 1: extract the subject, directly with pointer-network-style head/tail tagging

    sub_head_feature = Lambda(seq_gather)([tokens_feature, sub_head])
    sub_tail_feature = Lambda(seq_gather)([tokens_feature, sub_tail])
    sub_feature = Average()([sub_head_feature, sub_tail_feature])

    tokens_feature = Add()([tokens_feature, sub_feature])  # h + v: add the subject feature to every token feature
    pred_obj_heads = Dense(num_rels, activation='sigmoid')(tokens_feature) 
    pred_obj_tails = Dense(num_rels, activation='sigmoid')(tokens_feature)

    object_model = Model([tokens_in, segments_in, sub_head_in, sub_tail_in], [pred_obj_heads, pred_obj_tails])  # step 2: the object, also pointer-network-style tagging
    hbt_model = Model([tokens_in, segments_in, gold_sub_heads_in, gold_sub_tails_in, sub_head_in, sub_tail_in, gold_obj_heads_in, gold_obj_tails_in],
                        [pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails])  # this is the full model

    # The rest computes the individual losses
    gold_sub_heads = K.expand_dims(gold_sub_heads, 2)
    gold_sub_tails = K.expand_dims(gold_sub_tails, 2) 

    sub_heads_loss = K.binary_crossentropy(gold_sub_heads, pred_sub_heads)
    sub_heads_loss = K.sum(sub_heads_loss * mask) / K.sum(mask)  # * is element-wise multiplication; the result is a single scalar
    sub_tails_loss = K.binary_crossentropy(gold_sub_tails, pred_sub_tails)
    sub_tails_loss = K.sum(sub_tails_loss * mask) / K.sum(mask)

    obj_heads_loss = K.sum(K.binary_crossentropy(gold_obj_heads, pred_obj_heads), 2, keepdims=True)
    obj_heads_loss = K.sum(obj_heads_loss * mask) / K.sum(mask)
    obj_tails_loss = K.sum(K.binary_crossentropy(gold_obj_tails, pred_obj_tails), 2, keepdims=True)
    obj_tails_loss = K.sum(obj_tails_loss * mask) / K.sum(mask)

    loss = (sub_heads_loss + sub_tails_loss) + (obj_heads_loss + obj_tails_loss)

    hbt_model.add_loss(loss)
    hbt_model.compile(optimizer=Adam(LR))
    hbt_model.summary()

    return subject_model, object_model, hbt_model  # return all three models
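
The loss terms at the end of E2EModel can be sanity-checked outside Keras. Below is a minimal NumPy sketch of the same masked averaging, not the authors' code: the token ids, tags, and probabilities are made-up values, and binary_crossentropy and masked_mean are hypothetical helpers that only mimic what K.binary_crossentropy and the K.sum(loss * mask) / K.sum(mask) expressions appear to compute.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Element-wise binary cross-entropy; the output has the same shape as the inputs."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))

def masked_mean(per_token_loss, mask):
    """Sum the loss over real tokens only, then divide by the number of real tokens."""
    return float(np.sum(per_token_loss * mask) / np.sum(mask))

# One sentence of length 5; the ids are invented, and 0 marks padding.
tokens = np.array([[101, 2023, 102, 0, 0]])
mask = (tokens[:, :, None] > 0).astype(np.float64)  # shape (batch, seq_len, 1)

# Subject-head tags: (batch, seq_len) expanded to (batch, seq_len, 1).
gold_sub_heads = np.array([[0.0, 1.0, 0.0, 0.0, 0.0]])[:, :, None]
pred_sub_heads = np.array([[0.1, 0.9, 0.2, 0.99, 0.99]])[:, :, None]
sub_heads_loss = masked_mean(binary_crossentropy(gold_sub_heads, pred_sub_heads), mask)

# Object-head tags with num_rels = 2: sum over the relation axis first,
# keeping that axis so the result still lines up with the mask.
gold_obj_heads = np.zeros((1, 5, 2))
gold_obj_heads[0, 2, 1] = 1.0
pred_obj_heads = np.full((1, 5, 2), 0.1)
pred_obj_heads[0, 2, 1] = 0.9
per_token = binary_crossentropy(gold_obj_heads, pred_obj_heads).sum(axis=2, keepdims=True)
obj_heads_loss = masked_mean(per_token, mask)

print(sub_heads_loss, obj_heads_loss)
```

The two padded positions carry confident wrong predictions (0.99), yet they do not affect sub_heads_loss, because the mask zeroes them out before the sum.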

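The E2EModel code maps directly onto the model architecture and is concise. Some points are already annotated in the code comments.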
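
The post does not show inference. Since E2EModel returns subject_model and object_model separately, a plausible reading is that they are run in cascade: first decode subject spans from the head/tail probabilities, then feed each span to object_model. Below is a minimal sketch of the span decoding step only. The 0.5 threshold, the nearest-tail pairing rule, the example sentence, and the helper name decode_spans are all assumptions for illustration and are not taken from the code above.

```python
def decode_spans(head_probs, tail_probs, threshold=0.5):
    """Pair each predicted head with the nearest predicted tail at or after it.

    head_probs, tail_probs: per-token probabilities for one sentence.
    Returns a list of (head_index, tail_index) pairs, both inclusive.
    """
    heads = [i for i, p in enumerate(head_probs) if p > threshold]
    tails = [i for i, p in enumerate(tail_probs) if p > threshold]
    spans = []
    for head in heads:
        following = [t for t in tails if t >= head]
        if following:  # heads without any tail after them are dropped
            spans.append((head, following[0]))
    return spans

# Made-up tokens and probabilities for one sentence.
tokens = ["Jackie", "R", ".", "Brown", "was", "born", "in", "Washington"]
head_probs = [0.9, 0.1, 0.0, 0.2, 0.0, 0.0, 0.0, 0.8]
tail_probs = [0.1, 0.0, 0.0, 0.95, 0.0, 0.0, 0.0, 0.7]

spans = decode_spans(head_probs, tail_probs)
subjects = [" ".join(tokens[h:t + 1]) for h, t in spans]
print(spans, subjects)
```

Under this reading, each decoded (head, tail) pair would become the sub_head and sub_tail inputs of object_model. Its two outputs have one column per relation, so each column could be decoded in the same way to obtain the objects for that relation.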
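Open questions:
1. Why is the shape in sub_head_in = Input(shape=(1,)) equal to (1,)?
2. Every input layer is assigned to a second name. Is that because the inputs are used again later?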
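
A possible answer to the first question: the shape passed to Input excludes the batch dimension, so sub_head_in is a (batch, 1) tensor, that is, one token position per sample. That is consistent with seq_gather picking one token vector per sample. The sketch below illustrates this reading in NumPy. seq_gather is not shown in the post, so seq_gather_np is a hypothetical stand-in for what it presumably does, and all shapes and values are invented.

```python
import numpy as np

def seq_gather_np(seq, idxs):
    """Hypothetical NumPy stand-in for seq_gather.

    seq:  (batch, seq_len, hidden) token features
    idxs: (batch, 1) one token position per sample
    Returns (batch, hidden): the feature vector at that position.
    """
    positions = idxs.astype(np.int64)[:, 0]
    batch_idxs = np.arange(seq.shape[0])
    return seq[batch_idxs, positions]

batch, seq_len, hidden = 2, 4, 3
tokens_feature = np.arange(batch * seq_len * hidden, dtype=np.float64).reshape(batch, seq_len, hidden)

# One subject per sample: head and tail positions, each of shape (batch, 1).
sub_head = np.array([[1.0], [0.0]])
sub_tail = np.array([[2.0], [3.0]])

sub_head_feature = seq_gather_np(tokens_feature, sub_head)   # (batch, hidden)
sub_tail_feature = seq_gather_np(tokens_feature, sub_tail)   # (batch, hidden)
sub_feature = (sub_head_feature + sub_tail_feature) / 2.0     # what Average() computes

# h + v: the subject vector is broadcast over every token position.
conditioned = tokens_feature + sub_feature[:, None, :]        # (batch, seq_len, hidden)
print(conditioned.shape)
```

On the second question, the code itself offers a hint: only gold_sub_heads and gold_sub_tails are later rebound (to their K.expand_dims versions), while the *_in names keep pointing at the Input layers that the Model(...) calls need. For the other inputs, the second name appears to be just a shorter alias.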
