bert关系抽取论文源码之BERTem: Matching the Blanks: Distributional Similarity for Relation Learning

0 前言

BERTem模型是一个基于bert的预训练语言模型,BERTem模型本身并没有什么特别新奇之处,但通过本篇论文通过设计了一个新颖的预训练任务,在SemEval 2010 Task 8, KBP37, TACRED以及FewRel这四个数据集上达到了sota效果。本文将结合论文与pytorch源码对该模型进行解读。

1 模型架构

对于一个由两个实体 s 1 s_1 s1 s 2 s_2 s2 以及其上下文 x \mathtt x x 组成的关系描述 r = ( x , s 1 , s 2 ) \mathtt r=(\mathtt x ,s_1, s_2) r=(x,s1,s2) ,模型的目的是学习到一个由关系描述到关系的向量表示的映射 f θ ( r ) f_\theta(\mathtt r) fθ(r) 。本文对于关系描述的输入模式和输出的关系向量表示讨论了几种可能的选择,并最终选择了方式(f)作为模型架构,图示如下:

在这里插入图片描述

1.1 输入模式

第一种是STANDARD模式,使用bert的标准输入,不加特殊标记和向量嵌入。第二种方式是Positional embeddings ,即加入位置嵌入。第三种方式是Entity marker,即在实体两侧插入特殊标记以突出实体。本文选择的正是这种方式。在包含一个实体对的文本 x \mathtt x x 中,将特殊标识 [ E 1 s t a r t ] [E1_{start}] [E1start] [ E 1 e n d ] [E1_{end}] [E1end] [ E 2 s t a r t ] [E2_{start}] [E2start] [ E 2 e n d ] [E2_{end}] [E2end] 分别插入实体两侧,得到 x ~ \tilde\mathtt x x~

x ~ = [ x 0 . . . [ E 1 s t a r t ] x i . . . x j − 1 [ E 1 e n d ] . . . [ E 2 s t a r t ] x k . . . x l − 1 [ E 2 e n d ] ] \tilde\mathtt x = [x_0...[E1_{start}]x_i...x_{j-1}[E1_{end}]...[E2_{start}]x_k...x_{l-1}[E2_{end}]] x~=[x0...[E1start]xi...xj1[E1end]...[E2start]xk...xl1[E2end]]

1.2 输出的关系向量表示

文章讨论了三种可能的关系向量表示。第一种是直接使用CLS的向量表示,因为CLS的隐藏向量常被用做文本分类任务。第二种方式是Entity mention pooling,即对实体包含的token对应的词向量做最大池化。第三种是Entity start state,即使用两个实体前面的特殊标记 [ E 1 s t a r t ] [E1_{start}] [E1start] [ E 2 s t a r t ] [E2_{start}] [E2start] 对应的词向量,并将其进行拼接。这也是本文所选择的关系向量表示。

对于获得的向量表示,过一个全连接和softmax后使用交叉熵损失即可,图示如下:

在这里插入图片描述

由于这部分的源码如下所示:

    	
    	self.classification_layer = nn.Linear(1536, n_classes_)
        
        blankv1v2 = sequence_output[:, e1_e2_start, :]
        buffer = []
        for i in range(blankv1v2.shape[0]): # iterate batch & collect
            v1v2 = blankv1v2[i, i, :, :]
            v1v2 = torch.cat((v1v2[0], v1v2[1]))
            buffer.append(v1v2)
        del blankv1v2
        v1v2 = torch.stack([a for a in buffer], dim=0)
        del buffer
        
        if self.task == 'classification':
            classification_logits = self.classification_layer(v1v2)
            return classification_logits
            

这里的sequence_output就是BertEncoders最后输出的大小为(batch_size, sequence_len, hidden_state)的矩阵,包含了一个batch中所有文本 x ~ \tilde \mathtt x x~ 的向量表示。e1_e2_start是一个大小为(batch_size, 2)的矩阵,表示了每个上下文中两个实体的开始位置,也就是 [ E 1 s t a r t ] [E1_{start}] [E1start] [ E 2 s t a r t ] [E2_{start}] [E2start] 对应的位置。sequence_output[:, e1_e2_start, :] 表示对于batch_size个sequence_output中的元素,每个元素都会按e1_e2_start中的所有元素(也是batch_size个)依次进行索引,因此得到的blanckv1v2的大小为(batch_size, batch_size, sequence_len,hidden_state)。这也是为什么在获得 [ E 1 s t a r t ] [E1_{start}] [E1start] [ E 2 s t a r t ] [E2_{start}] [E2start] 对应的向量表示时需要 v1v2 = blankv1v2[i, i, :, :] 用两个 i 来选择相应的batch。可以看出来这里造成了不小的空间浪费,后续可以加以改进。

对于few-shot任务,只需计算候选关系向量表示和查询关系向量表示的点积,即可得到相似分数。模型图如下:

在这里插入图片描述

2 预训练:Matching the Blanks

接下来就是本篇文章的重头戏了。我们将从背景及损失函数、使用Blanks的原因,实现策略三个方面来介绍如何通过Matching the Blanks来获得关系描述 r \mathtt r r 到 关系向量表示的映射 f θ ( r ) f_\theta(\mathtt r) fθ(r)

要注意的是,在预训练时,除了 Matching the Blanks loss以外,也使用了 bert 的 masked language model loss 。

2.1 背景及损失函数

人工精标注的数据集质量很高,但往往因成本过大,而无法获得大量的训练数据。一种可行的解决办法是使用远程监督,但本文使用了一种更加简单暴力的策略。本文通过实体识别工具,构建一个训练集 D = [ ( r 0 , e 1 0 , e 2 0 ) . . . ( r N , e 1 N , e 2 N ) ] \mathcal D=[(\mathtt r^0, e_1^0, e_2^0 ) ... (\mathtt r^N, e_1^N, e_2^N)] D=[(r0,e10,e20)...(rN,e1N,e2N)] ,其中 r = ( x , s 1 , s 2 ) \mathtt r=(\mathtt x ,s_1, s_2) r=(x,s1,s2) 是一段关系描述 , e 1 , e 2 e_1, e_2 e1,e2 r \mathtt r r 中的实体, s 1 , s 2 s_1,s_2 s1,s2 是实体的下标。

文章观察到,对于两个关系描述 r = ( x , s 1 , s 2 ) , r ′ = ( x ′ , s 1 ′ , s 2 ′ ) \mathtt r=(\mathtt x, s_1, s_2) , \mathtt r'=(\mathtt x', s_1', s_2') r=(x,s1,s2),r=(x,s1,s2) ,如果 s 1 = s 1 ′ , s 2 = s 2 ′ s_1=s_1',s_2=s_2' s1=s1,s2=s2 ,那么 r , r ′ \mathtt{r,r'} r,r 很可能描述的是同一种关系。而当 r , r ′ \mathtt r, \mathtt r' r,r 表示的是相似的关系时,那么其关系向量表示的内积 f θ ( r ) ⊤ f θ ( r ) f_\theta(r) ^\top f_\theta(r) fθ(r)fθ(r) 应该尽量高。因此我们可以通过一个二分类来表示这种分布,公式如下:

p ( l = 1 ∣ r , r ′ ) = 1 1 + exp ⁡ f θ ( r ) ⊤ f θ ( r ′ ) p(l=1|\mathtt r,\mathtt r')= {1 \over {1+\exp f_\theta(\mathtt r)^\top f_\theta(\mathtt r')}} p(l=1r,r)=1+expfθ(r)fθ(r)1

这里 l = 1 l=1 l=1 表示两者表述类似的关系。

因此对于一个训练集 D = [ ( r 0 , e 1 0 , e 2 0 ) . . . ( r N , e 1 N , e 2 N ) ] \mathcal D=[(\mathtt r^0, e_1^0, e_2^0 ) ... (\mathtt r^N, e_1^N, e_2^N)] D=[(r0,e10,e20)...(rN,e1N,e2N)] ,想要学习到映射 f θ ( r ) f_\theta(r) fθ(r) ,只需要最小化损失函数即可,公式如下:

L ( D ) = − 1 ∣ D ∣ ∑ ( r , e 1 , e 2 ) ∈ D ∑ ( r ′ , e 1 ′ , e 2 ′ ) ∈ D δ e 1 , e 1 ′ δ e 2 , e 2 ′ ⋅ log ⁡ p ( l = 1 ∣ r , r ′ ) +     ( 1 − δ e 1 , e 1 ′ δ e 2 , e 2 ′ ) ⋅ log ⁡ ( 1 − p ( l = 1 ∣ r , r ′ ) ) \mathcal L(\mathcal D)= - {1\over\vert\mathcal D \vert} \sum_{(\mathtt r, e_1, e_2 )\in\mathtt D } \sum_{(\mathtt r', e_1', e_2' )\in\mathtt D} \\ \qquad \qquad \qquad\qquad\delta_{e_1,e_1'}\delta_{e_2,e_2'} \cdot \log p(l=1|\mathtt r,\mathtt r')+\\ \qquad \qquad \qquad\qquad \ \qquad \ \quad (1-\delta_{e_1,e_1'}\delta_{e_2,e_2'} )\cdot \log(1- p(l=1|\mathtt r,\mathtt r')) L(D)=D1(r,e1,e2)D(r,e1,e2)Dδe1,e1δe2,e2logp(l=1r,r)+  (1δe1,e1δe2,e2)log(1p(l=1r,r))

这里的 δ e , e ′ \delta_{e,e'} δe,e e = e ′ e=e' e=e 时为1,否则为0,即如果两个关系描述中的实体相同,那么这两者可能描述了相似的关系。

由于本小节与后面几个小节有密切联系,本部分源码将在下文中统一剖析。

2.2 使用Blanks的原因

观察损失函数 L ( D ) \mathcal L(\mathcal D) L(D) ,我们可以发现,我们用来构造训练数据 D \mathcal D D 的实体识别工具可以完美的最小化这个损失,因为 L ( D ) \mathcal L(\mathcal D) L(D) 本质上就是在关系描述中的实体相等是最大化概率 p ( l = 1 ∣ r , r ′ ) p(l=1|\mathtt r,\mathtt r') p(l=1r,r)。如果不加以改进的话,相当于对实体识别工具进行了重复学习。

为了避免模型简单的学习实体识别,文章在原有训练数据集 D \mathcal D D 的基础上构建了一个新的数据集 D ~ = [ ( r ~ 0 , e 1 0 , e 2 0 ) . . . ( r ~ N , e 1 N , e 2 N ) ] \tilde\mathcal D=[(\tilde\mathtt r^0, e_1^0, e_2^0 ) ... (\tilde\mathtt r^N, e_1^N, e_2^N)] D~=[(r~0,e10,e20)...(r~N,e1N,e2N)],对于关系描述 r \mathtt r r 中的每个实体,都有 α = 0.7 \alpha= 0.7 α=0.7 的概率替换为特殊标记 [BLANK] 。文章认为这样在最小化 L ( D ~ ) \mathcal L (\tilde \mathcal D) L(D~) 的过程中,模型要做的就不仅仅是命名实体识别,还有编码两个实体之间的语义关系。

(ps:这里也是一个我有疑问的地方 。 我们在联合实体识别和关系抽取的模型中也会使用实体识别任务来增强模型构建关系表示的能力。将实体替换为 [BLANK] 确实会使模型编码语义信息,但模型在训练时并没有采取和 bert 的 masked language model loss 任务一样的完型填空策略(让模型预测被mask掉的token),即没有在训练后使 BERTem 模型得知被替换为 [BLANK] 的实体,使得这样的预训练导致了一定的语义流失。文章虽然在实验中证明Matching the Blanks任务可以提升关系抽取的效果,但未将存粹的实体识别任务与Matching the Blanks任务作比较。这两者哪个更有效果也有待后续研究。)

本小节的源码也非常简介,如下所示:


    def put_blanks(self, D):
        blank_e1 = np.random.uniform()
        blank_e2 = np.random.uniform()
        if blank_e1 >= self.alpha:
            r, e1, e2 = D
            D = (r, "[BLANK]", e2)
        
        if blank_e2 >= self.alpha:
            r, e1, e2 = D
            D = (r, e1, "[BLANK]")
        return D

2.3 实现策略

在构建数据集 D \mathcal D D 时,本文使用了 Google Cloud Natural Language API ,而我找到的非官方实现中使用的时spacy。为了减轻有些共同出现但并未描述同一关系的实体对带来的噪声,文章只考虑在一个固定窗口大小(40)内的实体,即两个实体之间的间距若是超过窗口大小就不予考虑。

在实际中,如损失函数 L ( D ~ ) \mathcal L(\tilde\mathcal D) L(D~) 一样考虑 D \mathcal D D 中任意两个关系描述将会导致极高的复杂度,是不现实的。因此,本文使用了噪声对比估计,即考虑一个样本到底是属于目标分布(正例)还是噪声分布(负例)。本任务中,目标分布就是两个关系描述 r , r ′ r,r' r,r 之间共享两对实体,即认为很可能描述同一种关系;而噪声分布即两个关系描述之间不共享实体或者仅共享一对实体,这也是最小化损失函数的目标。对于目标分布,我们考虑所有正例,而对于数量远远多于正例的负例,文章在其中进行均匀采样,最后使得训练集中的正例和负例的样本数量基本相同。

3 源码实现详解

至此,本文的整体架构就已经非常清晰了,因为本文通篇有较强的联系,将代码放在每小节后可能会导致读者因为没看后面的内容而一头雾水,故将主体代码在这里统一剖析。下面我们来看一下这几部分的源码实现。由于篇幅问题,只选取关键部分代码,源码中还有许多巧妙的实现,在此就不一一赘述。

3.1 数据集构建

首先是数据集 D \mathcal D D 的构建,代码如下:

	
	nlp = spacy.load("en_core_web_lg")
    sents_doc = nlp(raw_text)
    ents = sents_doc.ents # get entities
    
    logger.info("Processing relation statements by entities...")
    entities_of_interest = ["PERSON", "NORP", "FAC", "ORG", "GPE", "LOC", "PRODUCT", "EVENT", \
                            "WORK_OF_ART", "LAW", "LANGUAGE"]
    length_doc = len(sents_doc)
    D = []; ents_list = []
    

这里对于实体识别,使用了spacy工具。


    for i in tqdm(range(len(ents))):
        e1 = ents[i]
        e1start = e1.start; e1end = e1.end
        if e1.label_ not in entities_of_interest:
            continue
        if re.search("[\d+]", e1.text): # entities should not contain numbers
            continue
        
        for j in range(1, len(ents) - i):
            e2 = ents[i + j]
            e2start = e2.start; e2end = e2.end
            if e2.label_ not in entities_of_interest:
                continue
            if re.search("[\d+]", e2.text): # entities should not contain numbers
                continue
            if e1.text.lower() == e2.text.lower(): # make sure e1 != e2
                continue
            
            if (1 <= (e2start - e1end) <= window_size): # check if next nearest entity within window_size
                # Find start of sentence
                punc_token = False
                start = e1start - 1
                if start > 0:
                    while not punc_token:
                        punc_token = sents_doc[start].is_punct
                        start -= 1
                        if start < 0:
                            break
                    left_r = start + 2 if start > 0 else 0
                else:
                    left_r = 0
                
                # Find end of sentence
                punc_token = False
                start = e2end
                if start < length_doc:
                    while not punc_token:
                        punc_token = sents_doc[start].is_punct
                        start += 1
                        if start == length_doc:
                            break
                    right_r = start if start < length_doc else length_doc
                else:
                    right_r = length_doc
                
                if (right_r - left_r) > window_size: # sentence should not be longer than window_size
                    continue
                
                x = [token.text for token in sents_doc[left_r:right_r]]
                
                r = (x, (e1start - left_r, e1end - left_r), (e2start - left_r, e2end - left_r))
                D.append((r, e1.text, e2.text))
                ents_list.append((e1.text, e2.text))

可以看到,通过 if (1 <= (e2start - e1end) <= window_size) ,间隔超过窗口大小40的实体将被过滤。punc_token = sents_doc[start].is_punct 表示判断当前字符是否为标点符号。通过查询实体对两侧最近的标点符号来获得文本描述 x \mathtt x x 的边界,进而获得关系描述 r \mathtt r r

3.2 噪声对比估计

3.2.1 目标与噪声分布

首先我们构建目标分布(正例)与噪声分布(负例),源码如下:

    
    self.df = pd.DataFrame(D, columns=['r','e1','e2'])
    
def __getitem__(self, idx):
    ### implements noise contrastive estimation
    ### get positive samples
    r, e1, e2 = self.df.iloc[idx]  # positive sample
    pool = self.df[((self.df['e1'] == e1) & (self.df['e2'] == e2))].index
    pos_idxs = np.random.choice(pool, \
                                size=min(int(self.batch_size / 2), len(pool)), replace=False)
    ### get negative samples
    if np.random.uniform() > 0.5:
        pool = self.df[((self.df['e1'] != e1) | (self.df['e2'] != e2))].index
        neg_idxs = np.random.choice(pool, \
                                    size=min(int(self.batch_size / 2), len(pool)), replace=False)
        Q = 1 / len(pool)

    else:
        if np.random.uniform() > 0.5:  # share e1 but not e2
            pool = self.df[((self.df['e1'] == e1) & (self.df['e2'] != e2))].index
            if len(pool) > 0:
                neg_idxs = np.random.choice(pool, \
                                            size=min(int(self.batch_size / 2), len(pool)), replace=False)
            else:
                neg_idxs = []

        else:  # share e2 but not e1
            pool = self.df[((self.df['e1'] != e1) & (self.df['e2'] == e2))].index
            if len(pool) > 0:
                neg_idxs = np.random.choice(pool, \
                                            size=min(int(self.batch_size / 2), len(pool)), replace=False)
            else:
                neg_idxs = []

        if len(neg_idxs) == 0:  # if empty, sample from all negatives
            pool = self.df[((self.df['e1'] != e1) | (self.df['e2'] != e2))].index
            neg_idxs = np.random.choice(pool, \
                                        size=min(int(self.batch_size / 2), len(pool)), replace=False)
        Q = 1 / len(pool)   

对于目标分布,我们从 3.1 小节中构建的数据集 D \mathcal D D 中选取所有与当前关系描述 r \mathtt r r 包含相同实体对的 r ′ \mathtt r' r ,并从中随机采样 b a t c h s i z e 2 \mathtt {batchsize} \over 2 2batchsize 个样本组成正例。(所以如果在预训练时显卡显存不足,可适当增加 gradient_acc_steps ,但最好不要减小batch_size)对于噪声分布也相同,即随机从共享一个实体、不共享实体的样本中选取。

3.2.2 分类器

本部分的代码与 1.2 小节中抽取关系向量表示的策略完全相同,只是在分类时略微有所差异,源码如下:


        self.activation = nn.Tanh()
        self.cls = BertOnlyMLMHead(config)
        
        blankv1v2 = sequence_output[:, e1_e2_start, :]
        buffer = []
        for i in range(blankv1v2.shape[0]): # iterate batch & collect
            v1v2 = blankv1v2[i, i, :, :]
            v1v2 = torch.cat((v1v2[0], v1v2[1]))
            buffer.append(v1v2)
        del blankv1v2
        v1v2 = torch.stack([a for a in buffer], dim=0)
        del buffer
        
        if self.task is None:
            blanks_logits = self.activation(v1v2) # self.blanks_linear(- torch.log(Q)
            lm_logits = self.cls(sequence_output)
            return blanks_logits, lm_logits

可以看到,对于 Matching the Blanks ,只是简单地对实体对拼接成的关系向量表示进行激活

3.2.3 Matching the Blanks loss

在介绍损失函数之前,先来看一些初始化参数和工具函数:


        self.lm_ignore_idx = lm_ignore_idx
        self.LM_criterion = nn.CrossEntropyLoss(ignore_index=self.lm_ignore_idx)
        self.use_logits = use_logits
        self.normalize = normalize
        
        if not self.use_logits:
            self.BCE_criterion = nn.BCELoss(reduction='mean')
        else:
            self.BCE_criterion = nn.BCEWithLogitsLoss(reduction='mean')
            
    def p_(self, f1_vec, f2_vec):
        if self.normalize:
            factor = 1/(torch.norm(f1_vec)*torch.norm(f2_vec))
        else:
            factor = 1.0
        
        if not self.use_logits:
            p = 1/(1 + torch.exp(-factor*torch.dot(f1_vec, f2_vec)))
        else:
            p = factor*torch.dot(f1_vec, f2_vec)
        return p
        

对于 Matching the Blanks loss,模型使用二元交叉熵损失函数,对于 bert 的 masked language model loss ,模型使用交叉熵损失。p_函数通过计算两个关系向量表示的内积来获得两个两个关系向量服从目标分布(描述相同关系)的概率 p ( l = 1 ∣ r , r ′ ) p(l=1|\mathtt r,\mathtt r') p(l=1r,r)

接下来是损失函数的主题部分:



    def forward(self, lm_logits, blank_logits, lm_labels, blank_labels, verbose=False):
        '''
        lm_logits: (batch_size, sequence_length, hidden_size)
        lm_labels: (batch_size, sequence_length, label_idxs)
        blank_logits: (batch_size, embeddings)
        blank_labels: (batch_size, 0 or 1)
        '''
        pos_idxs = [i for i, l in enumerate(blank_labels.squeeze().tolist()) if l == 1]
        neg_idxs = [i for i, l in enumerate(blank_labels.squeeze().tolist()) if l == 0]
        
        if len(pos_idxs) > 1:
            # positives
            pos_logits = []
            for pos1, pos2 in combinations(pos_idxs, 2):
                pos_logits.append(self.p_(blank_logits[pos1, :], blank_logits[pos2, :]))
            pos_logits = torch.stack(pos_logits, dim=0)
            pos_labels = [1.0 for _ in range(pos_logits.shape[0])]
        else:
            pos_logits, pos_labels = torch.FloatTensor([]), []
            if blank_logits.is_cuda:
                pos_logits = pos_logits.cuda()
        
        # negatives
        neg_logits = []
        for pos_idx in pos_idxs:
            for neg_idx in neg_idxs:
                neg_logits.append(self.p_(blank_logits[pos_idx, :], blank_logits[neg_idx, :]))
        neg_logits = torch.stack(neg_logits, dim=0)
        neg_labels = [0.0 for _ in range(neg_logits.shape[0])]
        
        blank_labels_ = torch.FloatTensor(pos_labels + neg_labels)
        
        if blank_logits.is_cuda:
            blank_labels_ = blank_labels_.cuda()
        
        lm_loss = self.LM_criterion(lm_logits, lm_labels)

        blank_loss = self.BCE_criterion(torch.cat([pos_logits, neg_logits], dim=0), \
                                        blank_labels_)

        if verbose:
            print("LM loss, blank_loss for last batch: %.5f, %.5f" % (lm_loss, blank_loss))
            
        total_loss = lm_loss + blank_loss
        return total_loss

模型首先将一个 batch 中的描述正例的关系向量之间的二组合,并将其计算所有二组合服从目标分布(描述相同关系)的概率 p ( l = 1 ∣ r , r ′ ) p(l=1|\mathtt r,\mathtt r') p(l=1r,r) 。对于负例也进行相同的操作。之后将 Matching the Blanks 和 masked language model 的逻辑结果过损失函数后得到两个预训练任务的损失,将其加和后得到总的损失。

4 结语

本文讨论了使用 bert 进行关系抽取几种可能的模型架构,并设计了一个新颖的预训练任务 Matching the Blanks,在四个数据集上均达到了sota。该预训练任务无需人工标注,同时可以适用于不同的模型架构,具有非常良好的鲁棒性,具有在实际中应用到更多模型中的潜力。

5 参考资料

论文地址:
https://arxiv.org/abs/1906.03158v1

非官方pytorch源码实现:
https://github.com/plkmo/BERT-Relation-Extraction

  • 9
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值