NLP - Information Extraction - Relation Extraction - 2014: CNN-based entity relation classifier [a pioneering work on neural-network relation extraction] [dataset: SemEval-2010 Task 8]

Original paper: Relation Classification via Convolutional Deep Neural Network (Zeng et al., COLING 2014)

I. Overview

1、Motivation

Before this paper, relation extraction relied mainly on statistical machine-learning methods, whose performance depends heavily on the quality of hand-crafted features.

Feature extraction in turn depends on the output of existing NLP tools (taggers, parsers, and so on), so errors made by those tools propagate into the relation classifier, and the whole pipeline stays tightly coupled to them.

These task dependencies also make training more complicated.

The question, then, is how to exploit entity and position information to train relation classification end to end.

2、Key points of the abstract

Current relation classification depends heavily on NLP preprocessing tools, which can cause errors to accumulate and propagate.

The paper proposes a relation classification method that needs no complex preprocessing.

Experimental results show the method is effective and reaches state-of-the-art performance.

3、Contributions


Proposed a CNN architecture that performs relation classification end to end:

  • Combines lexical-level features and sentence-level features
  • Introduces position features (PF) that encode the relative distance between the current word and each target entity (see the sketch below)
  • Fuses word-embedding information into the model for a better context representation
  • Evaluated on the SemEval-2010 Task 8 dataset, achieving the best results at the time
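
To make the position feature concrete, here is a minimal illustrative sketch (not the paper's code) of how the signed distance between a word and an entity is clipped to a fixed window and turned into an embedding index; it mirrors what __get_pos_index in dataset.py below does with pos_dis:

def relative_position_index(word_idx, entity_idx, pos_dis=20):
    """Map the signed word-entity distance into an index for an
    embedding table of size 2 * pos_dis + 3."""
    d = word_idx - entity_idx
    if d < -pos_dis:            # too far to the left
        return 0
    if d > pos_dis:             # too far to the right
        return 2 * pos_dis + 2
    return d + pos_dis + 1      # distances within [-pos_dis, pos_dis]

# e.g. with pos_dis=20: 3 tokens left of the entity -> 18,
# the entity itself -> 21, 40 tokens to the right -> 42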


4、Historical significance


II. Code walkthrough

1、Dataset

1.1 Raw dataset

train_file.txt (samples 1–8000). Each sample occupies four lines: the numbered sentence with <e1>/<e2> entity tags, the relation label, a comment line, and a blank line.

1	"The system as described above has its greatest application in an arrayed <e1>configuration</e1> of antenna <e2>elements</e2>."
Component-Whole(e2,e1)
Comment: Not a collection: there is structure here, organisation.

2	"The <e1>child</e1> was carefully wrapped and bound into the <e2>cradle</e2> by means of a cord."
Other
Comment:

3	"The <e1>author</e1> of a keygen uses a <e2>disassembler</e2> to look at the raw assembly code."
Instrument-Agency(e2,e1)
Comment:

4	"A misty <e1>ridge</e1> uprises from the <e2>surge</e2>."
Other
Comment:

5	"The <e1>student</e1> <e2>association</e2> is the voice of the undergraduate student population of the State University of New York at Buffalo."
Member-Collection(e1,e2)
Comment:

6	"This is the sprawling <e1>complex</e1> that is Peru's largest <e2>producer</e2> of silver."
Other
Comment:

7	"The current view is that the chronic <e1>inflammation</e1> in the distal part of the stomach caused by Helicobacter pylori <e2>infection</e2> results in an increased acid production from the non-infected upper corpus region of the stomach."
Cause-Effect(e2,e1)
Comment:

8	"<e1>People</e1> have been moving back into <e2>downtown</e2>."
Entity-Destination(e1,e2)
Comment:

9	"The <e1>lawsonite</e1> was contained in a <e2>platinum crucible</e2> and the counter-weight was a plastic crucible with metal pieces."
Content-Container(e1,e2)
Comment: prototypical example

10	"The solute was placed inside a beaker and 5 mL of the <e1>solvent</e1> was pipetted into a 25 mL glass <e2>flask</e2> for each trial."
Entity-Destination(e1,e2)
Comment:
......

test_file.txt (samples 8001–10717)

8001	"The most common <e1>audits</e1> were about <e2>waste</e2> and recycling."
Message-Topic(e1,e2)
Comment: Assuming an audit = an audit document.

8002	"The <e1>company</e1> fabricates plastic <e2>chairs</e2>."
Product-Producer(e2,e1)
Comment: (a) is satisfied

8003	"The school <e1>master</e1> teaches the lesson with a <e2>stick</e2>."
Instrument-Agency(e2,e1)
Comment:

8004	"The suspect dumped the dead <e1>body</e1> into a local <e2>reservoir</e2>."
Entity-Destination(e1,e2)
Comment:

8005	"Avian <e1>influenza</e1> is an infectious disease of birds caused by type A strains of the influenza <e2>virus</e2>."
Cause-Effect(e2,e1)
Comment:

8006	"The <e1>ear</e1> of the African <e2>elephant</e2> is significantly larger--measuring 183 cm by 114 cm in the bush elephant."
Component-Whole(e1,e2)
Comment:

8007	"A child is told a <e1>lie</e1> for several years by their <e2>parents</e2> before he/she realizes that a Santa Claus does not exist."
Product-Producer(e1,e2)
Comment: (a) is satisfied; negation is outside

8008	"Skype, a free software, allows a <e1>hookup</e1> of multiple computer <e2>users</e2> to join in an online conference call without incurring any telephone costs."
Member-Collection(e2,e1)
Comment:

8009	"The disgusting scene was retaliation against her brother Philip who rents the <e1>room</e1> inside this apartment <e2>house</e2> on Lombard street."
Component-Whole(e1,e2)
Comment:

8010	"This <e1>thesis</e1> defines the <e2>clinical characteristics</e2> of amyloid disease."
Message-Topic(e1,e2)
Comment: may be we could leave clinical out of e2.

1.2 Preprocessed data

The processor class below converts each four-line raw sample into one JSON line, recording the entity spans as token indices in the de-tagged, tokenized sentence:

import re
import json
from nltk.tokenize import word_tokenize


class processor(object):
    def __init__(self):
        pass

    def search_entity(self, sentence):
        e1 = re.findall(r'<e1>(.*)</e1>', sentence)[0]
        e2 = re.findall(r'<e2>(.*)</e2>', sentence)[0]
        sentence = sentence.replace('<e1>' + e1 + '</e1>', ' <e1> ' + e1 + ' </e1> ', 1)
        sentence = sentence.replace('<e2>' + e2 + '</e2>', ' <e2> ' + e2 + ' </e2> ', 1)
        sentence = word_tokenize(sentence)
        sentence = ' '.join(sentence)
        sentence = sentence.replace('< e1 >', '<e1>')
        sentence = sentence.replace('< e2 >', '<e2>')
        sentence = sentence.replace('< /e1 >', '</e1>')
        sentence = sentence.replace('< /e2 >', '</e2>')
        sentence = sentence.split()

        assert '<e1>' in sentence
        assert '<e2>' in sentence
        assert '</e1>' in sentence
        assert '</e2>' in sentence

        subj_start = subj_end = obj_start = obj_end = 0
        pure_sentence = []
        for i, word in enumerate(sentence):
            if '<e1>' == word:
                subj_start = len(pure_sentence)
                continue
            if '</e1>' == word:
                subj_end = len(pure_sentence) - 1
                continue
            if '<e2>' == word:
                obj_start = len(pure_sentence)
                continue
            if '</e2>' == word:
                obj_end = len(pure_sentence) - 1
                continue
            pure_sentence.append(word)
        return e1, e2, subj_start, subj_end, obj_start, obj_end, pure_sentence

    def convert(self, path_src, path_des):
        with open(path_src, 'r', encoding='utf-8') as fr:
            data = fr.readlines()
        with open(path_des, 'w', encoding='utf-8') as fw:
            for i in range(0, len(data), 4):
                id_s, sentence = data[i].strip().split('\t')
                sentence = sentence[1:-1]
                e1, e2, subj_start, subj_end, obj_start, obj_end, sentence = self.search_entity(sentence)
                meta = dict(
                    id=id_s,
                    relation=data[i + 1].strip(),
                    head=e1,
                    tail=e2,
                    subj_start=subj_start,
                    subj_end=subj_end,
                    obj_start=obj_start,
                    obj_end=obj_end,
                    sentence=sentence,
                    comment=data[i + 2].strip()[8:]  # drop the leading 'Comment:' prefix
                )
                json.dump(meta, fw, ensure_ascii=False)
                fw.write('\n')

train.json (produced by processor.convert; see the __main__ block in dataset.py below)

{"id": "1", "relation": "Component-Whole(e2,e1)", "head": "configuration", "tail": "elements", "subj_start": 12, "subj_end": 12, "obj_start": 15, "obj_end": 15, "sentence": ["The", "system", "as", "described", "above", "has", "its", "greatest", "application", "in", "an", "arrayed", "configuration", "of", "antenna", "elements", "."], "comment": " Not a collection: there is structure here, organisation."}
{"id": "2", "relation": "Other", "head": "child", "tail": "cradle", "subj_start": 1, "subj_end": 1, "obj_start": 9, "obj_end": 9, "sentence": ["The", "child", "was", "carefully", "wrapped", "and", "bound", "into", "the", "cradle", "by", "means", "of", "a", "cord", "."], "comment": ""}
{"id": "3", "relation": "Instrument-Agency(e2,e1)", "head": "author", "tail": "disassembler", "subj_start": 1, "subj_end": 1, "obj_start": 7, "obj_end": 7, "sentence": ["The", "author", "of", "a", "keygen", "uses", "a", "disassembler", "to", "look", "at", "the", "raw", "assembly", "code", "."], "comment": ""}
{"id": "4", "relation": "Other", "head": "ridge", "tail": "surge", "subj_start": 2, "subj_end": 2, "obj_start": 6, "obj_end": 6, "sentence": ["A", "misty", "ridge", "uprises", "from", "the", "surge", "."], "comment": ""}
{"id": "5", "relation": "Member-Collection(e1,e2)", "head": "student", "tail": "association", "subj_start": 1, "subj_end": 1, "obj_start": 2, "obj_end": 2, "sentence": ["The", "student", "association", "is", "the", "voice", "of", "the", "undergraduate", "student", "population", "of", "the", "State", "University", "of", "New", "York", "at", "Buffalo", "."], "comment": ""}
{"id": "6", "relation": "Other", "head": "complex", "tail": "producer", "subj_start": 4, "subj_end": 4, "obj_start": 10, "obj_end": 10, "sentence": ["This", "is", "the", "sprawling", "complex", "that", "is", "Peru", "'s", "largest", "producer", "of", "silver", "."], "comment": ""}
{"id": "7", "relation": "Cause-Effect(e2,e1)", "head": "inflammation", "tail": "infection", "subj_start": 7, "subj_end": 7, "obj_start": 19, "obj_end": 19, "sentence": ["The", "current", "view", "is", "that", "the", "chronic", "inflammation", "in", "the", "distal", "part", "of", "the", "stomach", "caused", "by", "Helicobacter", "pylori", "infection", "results", "in", "an", "increased", "acid", "production", "from", "the", "non-infected", "upper", "corpus", "region", "of", "the", "stomach", "."], "comment": ""}
{"id": "8", "relation": "Entity-Destination(e1,e2)", "head": "People", "tail": "downtown", "subj_start": 0, "subj_end": 0, "obj_start": 6, "obj_end": 6, "sentence": ["People", "have", "been", "moving", "back", "into", "downtown", "."], "comment": ""}
{"id": "9", "relation": "Content-Container(e1,e2)", "head": "lawsonite", "tail": "platinum crucible", "subj_start": 1, "subj_end": 1, "obj_start": 6, "obj_end": 7, "sentence": ["The", "lawsonite", "was", "contained", "in", "a", "platinum", "crucible", "and", "the", "counter-weight", "was", "a", "plastic", "crucible", "with", "metal", "pieces", "."], "comment": " prototypical example"}
{"id": "10", "relation": "Entity-Destination(e1,e2)", "head": "solvent", "tail": "flask", "subj_start": 12, "subj_end": 12, "obj_start": 20, "obj_end": 20, "sentence": ["The", "solute", "was", "placed", "inside", "a", "beaker", "and", "5", "mL", "of", "the", "solvent", "was", "pipetted", "into", "a", "25", "mL", "glass", "flask", "for", "each", "trial", "."], "comment": ""}
......

test.json

{"id": "8001", "relation": "Message-Topic(e1,e2)", "head": "audits", "tail": "waste", "subj_start": 3, "subj_end": 3, "obj_start": 6, "obj_end": 6, "sentence": ["The", "most", "common", "audits", "were", "about", "waste", "and", "recycling", "."], "comment": " Assuming an audit = an audit document."}
{"id": "8002", "relation": "Product-Producer(e2,e1)", "head": "company", "tail": "chairs", "subj_start": 1, "subj_end": 1, "obj_start": 4, "obj_end": 4, "sentence": ["The", "company", "fabricates", "plastic", "chairs", "."], "comment": " (a) is satisfied"}
{"id": "8003", "relation": "Instrument-Agency(e2,e1)", "head": "master", "tail": "stick", "subj_start": 2, "subj_end": 2, "obj_start": 8, "obj_end": 8, "sentence": ["The", "school", "master", "teaches", "the", "lesson", "with", "a", "stick", "."], "comment": ""}
{"id": "8004", "relation": "Entity-Destination(e1,e2)", "head": "body", "tail": "reservoir", "subj_start": 5, "subj_end": 5, "obj_start": 9, "obj_end": 9, "sentence": ["The", "suspect", "dumped", "the", "dead", "body", "into", "a", "local", "reservoir", "."], "comment": ""}
{"id": "8005", "relation": "Cause-Effect(e2,e1)", "head": "influenza", "tail": "virus", "subj_start": 1, "subj_end": 1, "obj_start": 16, "obj_end": 16, "sentence": ["Avian", "influenza", "is", "an", "infectious", "disease", "of", "birds", "caused", "by", "type", "A", "strains", "of", "the", "influenza", "virus", "."], "comment": ""}
{"id": "8006", "relation": "Component-Whole(e1,e2)", "head": "ear", "tail": "elephant", "subj_start": 1, "subj_end": 1, "obj_start": 5, "obj_end": 5, "sentence": ["The", "ear", "of", "the", "African", "elephant", "is", "significantly", "larger", "--", "measuring", "183", "cm", "by", "114", "cm", "in", "the", "bush", "elephant", "."], "comment": ""}
{"id": "8007", "relation": "Product-Producer(e1,e2)", "head": "lie", "tail": "parents", "subj_start": 5, "subj_end": 5, "obj_start": 11, "obj_end": 11, "sentence": ["A", "child", "is", "told", "a", "lie", "for", "several", "years", "by", "their", "parents", "before", "he/she", "realizes", "that", "a", "Santa", "Claus", "does", "not", "exist", "."], "comment": " (a) is satisfied; negation is outside"}
{"id": "8008", "relation": "Member-Collection(e2,e1)", "head": "hookup", "tail": "users", "subj_start": 8, "subj_end": 8, "obj_start": 12, "obj_end": 12, "sentence": ["Skype", ",", "a", "free", "software", ",", "allows", "a", "hookup", "of", "multiple", "computer", "users", "to", "join", "in", "an", "online", "conference", "call", "without", "incurring", "any", "telephone", "costs", "."], "comment": ""}
{"id": "8009", "relation": "Component-Whole(e1,e2)", "head": "room", "tail": "house", "subj_start": 12, "subj_end": 12, "obj_start": 16, "obj_end": 16, "sentence": ["The", "disgusting", "scene", "was", "retaliation", "against", "her", "brother", "Philip", "who", "rents", "the", "room", "inside", "this", "apartment", "house", "on", "Lombard", "street", "."], "comment": ""}
{"id": "8010", "relation": "Message-Topic(e1,e2)", "head": "thesis", "tail": "clinical characteristics", "subj_start": 1, "subj_end": 1, "obj_start": 4, "obj_end": 5, "sentence": ["This", "thesis", "defines", "the", "clinical", "characteristics", "of", "amyloid", "disease", "."], "comment": " may be we could leave clinical out of e2."}
......

1.3 relation2id.txt

The nine relation types each appear with both argument orders; together with Other this gives 19 classes:

Other	0
Cause-Effect(e1,e2)	1
Cause-Effect(e2,e1)	2
Component-Whole(e1,e2)	3
Component-Whole(e2,e1)	4
Content-Container(e1,e2)	5
Content-Container(e2,e1)	6
Entity-Destination(e1,e2)	7
Entity-Destination(e2,e1)	8
Entity-Origin(e1,e2)	9
Entity-Origin(e2,e1)	10
Instrument-Agency(e1,e2)	11
Instrument-Agency(e2,e1)	12
Member-Collection(e1,e2)	13
Member-Collection(e2,e1)	14
Message-Topic(e1,e2)	15
Message-Topic(e2,e1)	16
Product-Producer(e1,e2)	17
Product-Producer(e2,e1)	18

2、Pretrained word embeddings: static HLBL vectors

Each line of the file is a token followed by its 50-dimensional vector; the *UNKNOWN* entry serves as the out-of-vocabulary vector.

hlbl-embeddings-scaled.EMBEDDING_SIZE=50

*UNKNOWN* -0.166038776479 0.104395984608 0.163119732357 0.0899594154863 -0.0192271099805 -0.0417631572501 -0.0163376687927 0.0357616216019 0.0536077591673 0.0127688536503 -0.00284508433021 -0.0626207031228 -0.0379452734015 -0.103548297666 0.0381169119981 0.00199421074321 -0.0474636488659 -0.0127526851513 0.016404178535 -0.12759853361 -0.0292937037717 -0.0512566352549 0.0233097445983 0.0360505083995 0.00229317984472 -0.0771565284227 0.0071461584378 -0.051608090196 -0.0267547654304 0.0492994451068 -0.0531630844999 0.00787191810391 0.082280106873 0.066908641868 -0.0283930612982 0.216840166248 0.164923151267 0.00188498983723 0.0328679039324 -0.00175432516758 0.0614261774935 0.0987773071377 0.0548423375506 -0.0307057922059 0.053074241476 0.04982054279 -0.0572485864016 0.132236444766 -0.0379717035014 -0.120915939814
the -0.0841015569168 0.145263825738 0.116945121935 -0.0754618634155 0.17901499611 -0.000652852605208 -0.0713783879233 0.207273704502 0.060711721477 0.0366727701165 -0.0269791566731 -0.156993473526 -0.0393947453024 0.00749161628231 -0.332851634057 -0.1708430781 -0.275163605231 -0.266592614101 0.43349041466 -0.00779248211778 0.031101796379 -0.0257114150838 0.174856713352 -0.0543054233622 -0.0846669459476 -0.006234398456 0.00414488584462 0.119738648443 -0.0914876936952 -0.317381121871 -0.27471439742 0.234269597998 0.170305945138 -0.0282815073325 -0.10127814458 0.156451476203 0.154703520781 -0.0014827085612 0.164287521114 0.0328582913203 0.0356570354049 -0.190254406793 -0.112029936115 -0.198875312619 0.00102875631152 -0.00161517169984 -0.125210890327 0.196903181061 -0.112017915766 -0.00838804375065
. -0.0875932389444 -0.0586365253633 0.0729727126603 0.32072000431 0.0745620569276 -0.0494709138174 0.208708067552 -0.025035364294 -0.197531050237 0.177318202028 0.297077745222 -0.0256369072571 0.182364658364 0.189089099105 0.0589179494006 -0.0627276310572 0.0682898379459 0.241161712515 0.253510796291 -0.0325139691451 -0.0129081882483 -0.083367340352 0.0276167362372 -0.00757124183183 -0.0905801885623 0.305015208385 0.0755474920504 -0.00516459185438 -0.0412876867803 0.105047372601 -0.718674456034 0.184682477295 0.232732814491 0.0929975692214 0.0999329447708 -0.0968008990987 0.421525505372 -0.136460066398 -0.323294448817 0.118318915141 0.415411774103 -0.135770867168 0.0404792691614 0.264279769529 -0.133076243622 0.195087919022 -0.087589323012 0.0335223022065 -0.0365650611956 -0.0163760300203
, -0.023019838485 0.277215570968 0.241932261453 -0.105403438907 0.247316949736 0.0859618436243 -0.0130132156599 0.123988163629 -0.150741462418 0.129993766762 0.0766431623839 0.0547135456598 0.187342182554 0.176303102861 -0.121401723217 0.0458278230666 0.0339804870854 -0.0619606057248 0.0514787739809 0.00732501266557 0.0879996990484 -0.369288823679 0.235222707122 -0.0528783055204 0.0121891472663 -0.165169815904 -0.136829953355 -0.0750751223049 -0.0503433833321 0.0782539868365 -0.400940778018 -0.099745222007 -0.152448498545 -0.0815002789835 -0.010575616616 0.331604536668 -0.0124179474775 0.00173559407939 -0.230971231526 0.0162523457081 0.213848645598 0.184698023693 0.158368229826 0.0975422545404 -0.0307127563081 0.093420146492 -0.0377856184872 -0.0181716170654 0.43322993915 -0.113289957059
to 0.134693667961 0.392203653086 0.0346151199225 0.135354475458 0.0719918082372 0.118667933013 -0.0698386234679 -0.0139927084407 0.144452931939 0.0383223273458 -0.0491954394553 -0.126435975874 0.23979196724 -0.186550477314 0.0602616605691 -0.0875395769807 0.0788848675161 0.132691898026 0.155618778336 0.00680378469567 -0.126513561203 -0.436124771467 0.132675129426 -0.0946286638801 0.0986847070674 -0.354397304845 -0.196909463175 -0.0911408611189 0.134975690877 0.0625931974859 0.0108112360985 -0.107933544401 -0.166545488854 0.0137397678012 -0.0268394211932 -0.260328038765 0.0745185746772 0.020864049205 0.133485534344 -0.0479098207297 0.145382061477 -0.116284346216 0.0822848147919 -0.00621959258902 0.0135679910959 -0.0723116375013 -0.422793539068 0.144456402991 -0.119019192402 0.0659297394103
......
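
As a minimal reading sketch (assuming the whitespace-separated format shown above), the file can be parsed into a {word: vector} dictionary as follows; WordEmbeddingLoader.load_embedding() in dataset.py does essentially this while also building the word2id mapping:

import numpy as np

def read_hlbl_embeddings(path, dim=50):
    """Read 'token v1 ... v50' lines into a word -> float32 vector dict."""
    vectors = {}
    with open(path, 'r', encoding='utf-8') as fr:
        for line in fr:
            parts = line.strip().split()
            if len(parts) != dim + 1:   # skip malformed or header lines
                continue
            vectors[parts[0]] = np.asarray(parts[1:], dtype=np.float32)
    return vectors

# vecs = read_hlbl_embeddings('./embedding/hlbl-embeddings-scaled.EMBEDDING_SIZE=50.txt')
# vecs['the'].shape  -> (50,)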

3、config.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6

import argparse
import torch
import os
import random
import json
import numpy as np


class Config(object):
    def __init__(self):
        # get init config
        args = self.__get_config()
        for key in args.__dict__:
            setattr(self, key, args.__dict__[key])

        # select device
        self.device = None
        if self.cuda >= 0 and torch.cuda.is_available():
            self.device = torch.device('cuda:{}'.format(self.cuda))
        else:
            self.device = torch.device('cpu')

        # determine the model name and model dir
        if self.model_name is None:
            self.model_name = 'CNN2'
        self.model_dir = os.path.join(self.output_dir, self.model_name)
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        # backup data
        self.__config_backup(args)

        # set the random seed
        self.__set_seed(self.seed)

    def __get_config(self):
        parser = argparse.ArgumentParser()
        parser.description = 'config for models'

        # several key selective parameters
        parser.add_argument('--data_dir', type=str,
                            default='./data',
                            help='dir to load data')
        parser.add_argument('--output_dir', type=str,
                            default='./output',
                            help='dir to save output')

        # word embedding
        parser.add_argument('--embedding_path', type=str,
                            default='./embedding/hlbl-embeddings-scaled.EMBEDDING_SIZE=50.txt',
                            help='pre_trained word embedding')
        parser.add_argument('--word_dim', type=int,
                            default=50,
                            help='dimension of word embedding')
        # train settings
        parser.add_argument('--model_name', type=str,
                            default=None,
                            help='model name')
        parser.add_argument('--mode', type=int,
                            default=1,
                            choices=[0, 1],
                            help='running mode: 1 for training; otherwise testing')
        parser.add_argument('--seed', type=int,
                            default=666,
                            help='random seed')
        parser.add_argument('--cuda', type=int,
                            default=0,
                            help='num of gpu device, if -1, select cpu')
        parser.add_argument('--epoch', type=int,
                            default=100,
                            help='max number of epochs during training')
        # hyper parameters
        parser.add_argument('--dropout', type=float,
                            default=0.5,
                            help='dropout probability')
        parser.add_argument('--batch_size', type=int,
                            default=128,
                            help='batch size')
        parser.add_argument('--lr', type=float,
                            default=0.001,
                            help='learning rate')
        parser.add_argument('--max_len', type=int,
                            default=96,
                            help='max length of sentence')
        parser.add_argument('--pos_dis', type=int, default=20,
                            help='max distance of position embedding')
        parser.add_argument('--pos_dim', type=int,
                            default=5,
                            help='dimension of position embedding')
        parser.add_argument('--hidden_size', type=int, default=100,
                            help='the size of linear layer between convolution and classification')

        # hyper parameters for cnn
        parser.add_argument('--filter_num', type=int, default=200,
                            help='the number of filters in convolution')
        parser.add_argument('--window', type=int, default=3,
                            help='the size of window in convolution')

        parser.add_argument('--L2_decay', type=float, default=0.0001,
                            help='L2 weight decay')

        args = parser.parse_args()
        return args

    def __set_seed(self, seed=1234):
        os.environ['PYTHONHASHSEED'] = '{}'.format(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)  # set seed for cpu
        torch.cuda.manual_seed(seed)  # set seed for current gpu
        torch.cuda.manual_seed_all(seed)  # set seed for all gpu

    def __config_backup(self, args):
        config_backup_path = os.path.join(self.model_dir, 'config.json')
        with open(config_backup_path, 'w', encoding='utf-8') as fw:
            json.dump(vars(args), fw, ensure_ascii=False)

    def print_config(self):
        for key in self.__dict__:
            print(key, end=' = ')
            print(self.__dict__[key])


if __name__ == '__main__':
    config = Config()
    config.print_config()

4、dataset.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6


import os
import torch
import numpy as np
import json
import re
from nltk.tokenize import word_tokenize
from torch.utils.data import Dataset, DataLoader


class WordEmbeddingLoader(object):
    """
    A loader for pre-trained word embedding
    """

    def __init__(self, config):
        self.path_word = config.embedding_path  # path of pre-trained word embedding
        self.word_dim = config.word_dim  # dimension of word embedding

    def trim_from_pre_embedding(self, vocab):
        word2id = dict()
        word_vec = {}
        trim_word_vec = list()
        with open(self.path_word, 'r', encoding='utf-8') as fr:
            for line in fr:
                line = line.strip().split()
                if len(line) != self.word_dim + 1:
                    continue
                word_vec[line[0]] = np.asarray(line[1:], dtype=np.float32)
        for word in vocab:
            word2id[word] = len(word2id)
            if (word in word_vec):
                trim_word_vec.append(word_vec[word])
            else:
                trim_word_vec.append(np.random.uniform(-1, 1, self.word_dim))
        # add the special tokens *UNKNOWN* and PAD if they are missing
        if ("*UNKNOWN*" not in word2id):
            word2id['*UNKNOWN*'] = len(word2id)
            unk_emb = np.random.uniform(-1, 1, self.word_dim)
            trim_word_vec.append(unk_emb)
        if ("PAD" not in word2id):
            word2id['PAD'] = len(word2id)
            pad_emb = np.zeros(self.word_dim)
            trim_word_vec.append(pad_emb)  # PAD is initialized as a zero vector
        trim_word_vec = np.array(trim_word_vec)
        trim_word_vec = trim_word_vec.astype(np.float32).reshape(-1, self.word_dim)
        return word2id, torch.from_numpy(trim_word_vec)

    def load_embedding(self):
        word2id = dict()  # word to wordID
        word_vec = list()  # wordID to word embedding
        word2id['PAD'] = len(word2id)  # PAD character

        with open(self.path_word, 'r', encoding='utf-8') as fr:
            for line in fr:
                line = line.strip().split()
                if len(line) != self.word_dim + 1:
                    continue
                word2id[line[0]] = len(word2id)
                word_vec.append(np.asarray(line[1:], dtype=np.float32))
        if ("*UNKNOWN*" not in word2id):
            word2id['*UNKNOWN*'] = len(word2id)
            unk_emb = np.random.uniform(-1, 1, self.word_dim)
            word_vec.append(unk_emb)
        pad_emb = np.zeros([1, self.word_dim], dtype=np.float32)  # PAD is initialized as a zero vector
        word_vec = np.concatenate((pad_emb, word_vec), axis=0)
        word_vec = word_vec.astype(np.float32).reshape(-1, self.word_dim)
        word_vec = torch.from_numpy(word_vec)
        return word2id, word_vec


class RelationLoader(object):
    def __init__(self, config):
        self.data_dir = config.data_dir

    def __load_relation(self):
        relation_file = os.path.join(self.data_dir, 'relation2id.txt')
        rel2id = {}
        id2rel = {}
        with open(relation_file, 'r', encoding='utf-8') as fr:
            for line in fr:
                relation, id_s = line.strip().split()
                id_d = int(id_s)
                rel2id[relation] = id_d
                id2rel[id_d] = relation
        return rel2id, id2rel, len(rel2id)

    def get_relation(self):
        return self.__load_relation()


class SemEvalDataset(Dataset):
    def __init__(self, filename, rel2id, word2id, config):
        self.filename = filename
        self.rel2id = rel2id
        self.word2id = word2id
        self.max_len = config.max_len
        self.pos_dis = config.pos_dis
        self.data_dir = config.data_dir
        self.dataset, self.label = self.__load_data()

    def __get_pos_index(self, x):
        if x < -self.pos_dis:
            return 0
        if x >= -self.pos_dis and x <= self.pos_dis:
            return x + self.pos_dis + 1
        if x > self.pos_dis:
            return 2 * self.pos_dis + 2

    def __get_relative_pos(self, x, entity_pos):
        if x < entity_pos[0]:
            return self.__get_pos_index(x - entity_pos[0])
        elif x > entity_pos[1]:
            return self.__get_pos_index(x - entity_pos[1])
        else:
            return self.__get_pos_index(0)

    def __symbolize_sentence(self, e1_pos, e2_pos, sentence):
        """
            Args:
                e1_pos (tuple) span of e1
                e2_pos (tuple) span of e2
                sentence (list)
        """
        # segment mask: 1 before the first entity, 2 from the first entity's start
        # through the second entity's end, 3 after it; 0 marks PAD (added later)
        mask = [1] * len(sentence)
        if e1_pos[0] < e2_pos[0]:
            for i in range(e1_pos[0], e2_pos[1] + 1):
                mask[i] = 2
            for i in range(e2_pos[1] + 1, len(sentence)):
                mask[i] = 3
        else:
            for i in range(e2_pos[0], e1_pos[1] + 1):
                mask[i] = 2
            for i in range(e1_pos[1] + 1, len(sentence)):
                mask[i] = 3

        words = []
        pos1 = []
        pos2 = []
        length = min(self.max_len, len(sentence))
        mask = mask[:length]

        for i in range(length):
            words.append(self.word2id.get(sentence[i], self.word2id['*UNKNOWN*']))
            pos1.append(self.__get_relative_pos(i, e1_pos))
            pos2.append(self.__get_relative_pos(i, e2_pos))

        if length < self.max_len:
            for i in range(length, self.max_len):
                mask.append(0)  # 'PAD' mask is zero
                words.append(self.word2id['PAD'])

                pos1.append(self.__get_relative_pos(i, e1_pos))
                pos2.append(self.__get_relative_pos(i, e2_pos))
        unit = np.asarray([words, pos1, pos2, mask], dtype=np.int64)
        unit = np.reshape(unit, newshape=(1, 4, self.max_len))
        return unit

    def _lexical_feature(self, e1_idx, e2_idx, sent):
        def _entity_context(e_idx, sent):
            ''' return [w(e), w(e-1), w(e+1)]: the entity word plus its left and
                right neighbors (the entity word itself is reused at sentence boundaries)
            '''
            context = []
            context.append(sent[e_idx])
            if e_idx >= 1:
                context.append(sent[e_idx - 1])
            else:
                context.append(sent[e_idx])

            if e_idx < len(sent) - 1:
                context.append(sent[e_idx + 1])
            else:
                context.append(sent[e_idx])
            return context

        # print(e1_idx,sent)
        context1 = _entity_context(e1_idx[0], sent)
        context2 = _entity_context(e2_idx[0], sent)
        # ignore WordNet hypernyms in paper
        lexical = context1 + context2
        lexical_ids = [self.word2id.get(word, self.word2id['*UNKNOWN*']) for word in lexical]
        lexical_ids = np.asarray(lexical_ids, dtype=np.int64)
        return np.reshape(lexical_ids, newshape=(1, 6))

    def __load_data(self):
        path_data_file = os.path.join(self.data_dir, self.filename)
        data = []
        labels = []
        with open(path_data_file, 'r', encoding='utf-8') as fr:
            for line in fr:
                line = json.loads(line.strip())
                label = line['relation']
                sentence = line['sentence']
                e1_pos = (line['subj_start'], line['subj_end'])
                e2_pos = (line['obj_start'], line['obj_end'])
                label_idx = self.rel2id[label]

                one_sentence = self.__symbolize_sentence(e1_pos, e2_pos, sentence)
                lexical = self._lexical_feature(e1_pos, e2_pos, sentence)
                temp = (one_sentence, lexical)
                data.append(temp)
                # data.append(one_sentence)
                labels.append(label_idx)
        return data, labels

    def __getitem__(self, index):
        data = self.dataset[index]
        label = self.label[index]
        return data, label

    def __len__(self):
        return len(self.label)


class SemEvalDataLoader(object):
    def __init__(self, rel2id, word2id, config):
        self.rel2id = rel2id
        self.word2id = word2id
        self.config = config

    def __collate_fn(self, batch):
        data, label = zip(*batch)  # unzip the batch data
        data = list(data)
        label = list(label)
        sentence_feat = torch.from_numpy(np.concatenate([x[0] for x in data], axis=0))
        lexical_feat = torch.from_numpy(np.concatenate([x[1] for x in data], axis=0))
        label = torch.from_numpy(np.asarray(label, dtype=np.int64))
        return (sentence_feat, lexical_feat), label

    def __get_data(self, filename, shuffle=False):
        dataset = SemEvalDataset(filename, self.rel2id, self.word2id, self.config)
        loader = DataLoader(
            dataset=dataset,
            batch_size=self.config.batch_size,
            shuffle=shuffle,
            num_workers=2,
            collate_fn=self.__collate_fn
        )
        return loader

    def get_train(self):
        return self.__get_data('train.json', shuffle=True)

    def get_dev(self):
        # no separate dev split is provided; the test set is reused for validation
        return self.__get_data('test.json', shuffle=False)

    def get_test(self):
        return self.__get_data('test.json', shuffle=False)


class processor(object):
    def __init__(self):
        pass

    def search_entity(self, sentence):
        e1 = re.findall(r'<e1>(.*)</e1>', sentence)[0]
        e2 = re.findall(r'<e2>(.*)</e2>', sentence)[0]
        sentence = sentence.replace('<e1>' + e1 + '</e1>', ' <e1> ' + e1 + ' </e1> ', 1)
        sentence = sentence.replace('<e2>' + e2 + '</e2>', ' <e2> ' + e2 + ' </e2> ', 1)
        sentence = word_tokenize(sentence)
        sentence = ' '.join(sentence)
        sentence = sentence.replace('< e1 >', '<e1>')
        sentence = sentence.replace('< e2 >', '<e2>')
        sentence = sentence.replace('< /e1 >', '</e1>')
        sentence = sentence.replace('< /e2 >', '</e2>')
        sentence = sentence.split()

        assert '<e1>' in sentence
        assert '<e2>' in sentence
        assert '</e1>' in sentence
        assert '</e2>' in sentence

        subj_start = subj_end = obj_start = obj_end = 0
        pure_sentence = []
        for i, word in enumerate(sentence):
            if '<e1>' == word:
                subj_start = len(pure_sentence)
                continue
            if '</e1>' == word:
                subj_end = len(pure_sentence) - 1
                continue
            if '<e2>' == word:
                obj_start = len(pure_sentence)
                continue
            if '</e2>' == word:
                obj_end = len(pure_sentence) - 1
                continue
            pure_sentence.append(word)
        return e1, e2, subj_start, subj_end, obj_start, obj_end, pure_sentence

    def convert(self, path_src, path_des):
        with open(path_src, 'r', encoding='utf-8') as fr:
            data = fr.readlines()
        with open(path_des, 'w', encoding='utf-8') as fw:
            for i in range(0, len(data), 4):
                id_s, sentence = data[i].strip().split('\t')
                sentence = sentence[1:-1]
                e1, e2, subj_start, subj_end, obj_start, obj_end, sentence = self.search_entity(sentence)
                meta = dict(
                    id=id_s,
                    relation=data[i + 1].strip(),
                    head=e1,
                    tail=e2,
                    subj_start=subj_start,
                    subj_end=subj_end,
                    obj_start=obj_start,
                    obj_end=obj_end,
                    sentence=sentence,
                    comment=data[i + 2].strip()[8:]
                )
                json.dump(meta, fw, ensure_ascii=False)
                fw.write('\n')


class VocabGenerator(object):
    def __init__(self, train_path, test_path):
        self.train_path = train_path
        self.test_path = test_path

    def get_vocab(self):
        vocab = {}
        with open(self.train_path, 'r', encoding='utf-8') as fr:
            for line in fr:
                line = json.loads(line.strip())
                sentence = line['sentence']
                for word in sentence:
                    vocab[word] = 1
        with open(self.test_path, 'r', encoding='utf-8') as fr:
            for line in fr:
                line = json.loads(line.strip())
                sentence = line['sentence']
                for word in sentence:
                    vocab[word] = 1

        return vocab.keys()


if __name__ == '__main__':
    path_train = 'data/train_file.txt'
    path_test = 'data/test_file.txt'
    processor1 = processor()
    processor1.convert(path_train, 'data/train.json')
    processor1.convert(path_test, 'data/test.json')
    vocab = VocabGenerator('data/train.json', 'data/test.json').get_vocab()

    # from config import Config
    # config = Config()
    # word2id, word_vec = WordEmbeddingLoader(config).load_embedding()
    # rel2id, id2rel, class_num = RelationLoader(config).get_relation()
    # loader = SemEvalDataLoader(rel2id, word2id, config)
    # test_loader = loader.get_train()
    #
    # min_v, max_v = float('inf'), -float('inf')
    # for step, (data, label) in enumerate(test_loader):
    #     # print(type(data), data.shape)
    #     # print(type(label), label.shape)
    #     # break
    #     pos1 = data[:, 1, :].view(-1, config.max_len)
    #     pos2 = data[:, 2, :].view(-1, config.max_len)
    #     mask = data[:, 3, :].view(-1, config.max_len)
    #     min_v = min(min_v, torch.min(pos1).item())
    #     max_v = max(max_v, torch.max(pos1).item())
    #     min_v = min(min_v, torch.min(pos2).item())
    #     max_v = max(max_v, torch.max(pos2).item())
    # print(min_v, max_v)
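
As a quick sanity check (a hedged sketch, assuming data/train.json, data/test.json and the embedding file have already been generated as above), one batch from the loader has exactly the shapes that model.py expects:

from config import Config
from dataset import WordEmbeddingLoader, RelationLoader, SemEvalDataLoader

config = Config()
word2id, word_vec = WordEmbeddingLoader(config).load_embedding()
rel2id, id2rel, class_num = RelationLoader(config).get_relation()
loader = SemEvalDataLoader(rel2id, word2id, config)

(sentence_feat, lexical_feat), label = next(iter(loader.get_train()))
print(sentence_feat.shape)  # (batch_size, 4, max_len): token ids, pos1, pos2, mask
print(lexical_feat.shape)   # (batch_size, 6): e1 and e2 with their left/right neighbors
print(label.shape)          # (batch_size,)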

5、model.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


class CNN(nn.Module):
    def __init__(self, word_vec, class_num, config):
        super().__init__()
        self.word_vec = word_vec
        self.class_num = class_num

        # hyper parameters and others
        self.max_len = config.max_len
        self.word_dim = config.word_dim
        self.pos_dim = config.pos_dim
        self.pos_dis = config.pos_dis

        self.dropout_value = config.dropout
        self.filter_num = config.filter_num
        self.window = config.window
        self.hidden_size = config.hidden_size

        self.dim = self.word_dim + 2 * self.pos_dim

        # net structures and operations
        self.word_embedding = nn.Embedding.from_pretrained(embeddings=self.word_vec, freeze=False, )
        self.pos1_embedding = nn.Embedding(num_embeddings=2 * self.pos_dis + 3, embedding_dim=self.pos_dim)
        self.pos2_embedding = nn.Embedding(num_embeddings=2 * self.pos_dis + 3, embedding_dim=self.pos_dim)

        self.conv = nn.Conv2d(
            in_channels=1,
            out_channels=self.filter_num,
            kernel_size=(self.window, self.dim),
            stride=(1, 1),
            bias=False,
            padding=(1, 0),  # 'same' padding along the sequence length for the default window=3
            padding_mode='zeros'
        )
        self.maxpool = nn.MaxPool2d((self.max_len, 1))
        self.tanh = nn.Tanh()
        self.dropout = nn.Dropout(self.dropout_value)
        self.linear = nn.Linear(in_features=self.filter_num, out_features=self.hidden_size, bias=False)
        self.dense = nn.Linear(in_features=self.hidden_size + 6 * self.word_dim, out_features=self.class_num, bias=False)

        # initialize weight
        init.xavier_normal_(self.pos1_embedding.weight)
        init.xavier_normal_(self.pos2_embedding.weight)
        init.xavier_normal_(self.conv.weight)
        # init.constant_(self.conv.bias, 0.)
        init.xavier_normal_(self.linear.weight)
        # init.constant_(self.linear.bias, 0.)
        init.xavier_normal_(self.dense.weight)
        # init.constant_(self.dense.bias, 0.)

    def encoder_layer(self, token, pos1, pos2):
        word_emb = self.word_embedding(token)  # B*L*word_dim
        pos1_emb = self.pos1_embedding(pos1)  # B*L*pos_dim
        pos2_emb = self.pos2_embedding(pos2)  # B*L*pos_dim
        emb = torch.cat(tensors=[word_emb, pos1_emb, pos2_emb], dim=-1)
        return emb  # B*L*D, D=word_dim+2*pos_dim

    def conv_layer(self, emb, mask):
        emb = emb.unsqueeze(dim=1)  # B*1*L*D
        conv = self.conv(emb)  # B*C*L*1

        # mask, remove the effect of 'PAD'
        conv = conv.view(-1, self.filter_num, self.max_len)  # B*C*L
        mask = mask.unsqueeze(dim=1)  # B*1*L
        mask = mask.expand(-1, self.filter_num, -1)  # B*C*L
        conv = conv.masked_fill_(mask.eq(0), float('-inf'))  # B*C*L
        conv = conv.unsqueeze(dim=-1)  # B*C*L*1
        return conv

    def single_maxpool_layer(self, conv):
        pool = self.maxpool(conv)  # B*C*1*1
        pool = pool.view(-1, self.filter_num)  # B*C
        return pool

    def forward(self, data):
        token = data[0][:, 0, :].view(-1, self.max_len)
        pos1 = data[0][:, 1, :].view(-1, self.max_len)
        pos2 = data[0][:, 2, :].view(-1, self.max_len)
        mask = data[0][:, 3, :].view(-1, self.max_len)
        lexical = data[1].view(-1, 6)
        lexical_emb = self.word_embedding(lexical)
        lexical_emb = lexical_emb.view(-1, self.word_dim * 6)
        emb = self.encoder_layer(token, pos1, pos2)
        emb = self.dropout(emb)
        conv = self.conv_layer(emb, mask)
        pool = self.single_maxpool_layer(conv)
        sentence_feature = self.linear(pool)
        sentence_feature = self.tanh(sentence_feature)
        sentence_feature = self.dropout(sentence_feature)
        features = torch.cat((lexical_emb, sentence_feature), 1)
        logits = self.dense(features)
        return logits
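
As a quick shape trace, the following hedged sketch runs the model on random dummy inputs (a toy embedding table and a SimpleNamespace standing in for Config; the sizes are just the defaults above, not values from the paper):

import torch
from types import SimpleNamespace
from model import CNN

cfg = SimpleNamespace(max_len=96, word_dim=50, pos_dim=5, pos_dis=20,
                      dropout=0.5, filter_num=200, window=3, hidden_size=100)
vocab_size, class_num, batch_size = 1000, 19, 4
word_vec = torch.randn(vocab_size, cfg.word_dim)   # toy embedding table

model = CNN(word_vec=word_vec, class_num=class_num, config=cfg)

tokens = torch.randint(0, vocab_size, (batch_size, cfg.max_len))
pos1 = torch.randint(0, 2 * cfg.pos_dis + 3, (batch_size, cfg.max_len))
pos2 = torch.randint(0, 2 * cfg.pos_dis + 3, (batch_size, cfg.max_len))
mask = torch.ones(batch_size, cfg.max_len, dtype=torch.long)
sentence_feat = torch.stack([tokens, pos1, pos2, mask], dim=1)   # B * 4 * L
lexical_feat = torch.randint(0, vocab_size, (batch_size, 6))     # B * 6

logits = model((sentence_feat, lexical_feat))
print(logits.shape)  # torch.Size([4, 19]): one score per relation class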

6、train_or_test.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6

import os
import torch
import torch.nn as nn
import torch.optim as optim

from config import Config
from dataset import WordEmbeddingLoader, RelationLoader, SemEvalDataLoader, VocabGenerator
from model import CNN
from evaluate import Eval


def print_result(predict_label, id2rel, start_idx=8001):
    with open('script/predicted_result.txt', 'w', encoding='utf-8') as fw:
        for i in range(0, predict_label.shape[0]):
            fw.write('{}\t{}\n'.format(start_idx + i, id2rel[int(predict_label[i])]))


def train(model, criterion, loader, config):
    train_loader, dev_loader, _ = loader
    # weight_decay_list = (param for name, param in model.named_parameters() if name[-4:] != 'bias' and "bn" not in name)
    # no_decay_list = (param for name, param in model.named_parameters() if name[-4:] == 'bias' or "bn" in name)
    # parameters = [{'params': weight_decay_list},
    #               {'params': no_decay_list, 'weight_decay': 0.}]

    optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.L2_decay)

    print(model)
    print('training model parameters:')
    for name, param in model.named_parameters():
        if param.requires_grad:
            print('%s :  %s' % (name, str(param.data.shape)))
    print('--------------------------------------')
    print('start to train the model ...')

    eval_tool = Eval(config)
    best_f1 = -float('inf')  # best dev F1 seen so far
    for epoch in range(1, config.epoch + 1):
        for step, (data, label) in enumerate(train_loader):
            model.train()

            sent_feat = data[0].to(config.device)
            lex_feat = data[1].to(config.device)
            data = (sent_feat, lex_feat)
            # data = data.to(config.device)
            label = label.to(config.device)

            optimizer.zero_grad()
            logits = model(data)
            loss = criterion(logits, label)
            loss.backward()
            optimizer.step()

        _, train_loss, _ = eval_tool.evaluate(model, criterion, train_loader)
        f1, dev_loss, _ = eval_tool.evaluate(model, criterion, dev_loader)

        print('[%03d] train_loss: %.3f | dev_loss: %.3f | micro f1 on dev: %.4f'
              % (epoch, train_loss, dev_loss, f1), end=' ')
        if f1 > best_f1:
            best_f1 = f1
            torch.save(model.state_dict(), os.path.join(config.model_dir, 'model.pkl'))
            print('>>> save models!')
        else:
            print()


def test(model, criterion, loader, config):
    print('--------------------------------------')
    print('start test ...')
    _, _, test_loader = loader
    model.load_state_dict(torch.load(os.path.join(config.model_dir, 'model.pkl')))
    eval_tool = Eval(config)
    f1, test_loss, predict_label = eval_tool.evaluate(model, criterion, test_loader)
    print('test_loss: %.3f | micro f1 on test:  %.4f' % (test_loss, f1))
    return predict_label


if __name__ == '__main__':
    config = Config()
    print('--------------------------------------')
    print('some config:')
    config.print_config()

    print('--------------------------------------')
    print('start to load data ...')
    vocab = VocabGenerator('data/train.json', 'data/test.json').get_vocab()
    word2id, word_vec = WordEmbeddingLoader(config).load_embedding()
    rel2id, id2rel, class_num = RelationLoader(config).get_relation()
    loader = SemEvalDataLoader(rel2id, word2id, config)

    train_loader, dev_loader = None, None
    if config.mode == 1:  # train mode
        train_loader = loader.get_train()
        dev_loader = loader.get_dev()
    test_loader = loader.get_test()
    loader = [train_loader, dev_loader, test_loader]
    print('finish!')

    print('--------------------------------------')
    model = CNN(word_vec=word_vec, class_num=class_num, config=config)
    model = model.to(config.device)
    criterion = nn.CrossEntropyLoss()

    if config.mode == 1:  # train mode
        train(model, criterion, loader, config)
    predict_label = test(model, criterion, loader, config)
    print_result(predict_label, id2rel)

7、evaluate.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6

import numpy as np
import torch


def semeval_scorer(predict_label, true_label, class_num=10):
    """Macro-F1 over the 9 relation families, with the two directions of each
    relation folded together via ceil(label / 2); 'Other' (index 0) is excluded.
    Right-family / wrong-direction predictions (the xDIRx counters) count
    against both precision and recall of that family."""
    import math
    assert true_label.shape[0] == predict_label.shape[0]
    confusion_matrix = np.zeros(shape=[class_num, class_num], dtype=np.float32)
    xDIRx = np.zeros(shape=[class_num], dtype=np.float32)  # wrong-direction errors per family
    for i in range(true_label.shape[0]):
        true_idx = math.ceil(true_label[i] / 2)
        predict_idx = math.ceil(predict_label[i] / 2)
        if true_label[i] == predict_label[i]:
            confusion_matrix[predict_idx][true_idx] += 1
        else:
            if true_idx == predict_idx:
                xDIRx[predict_idx] += 1
            else:
                confusion_matrix[predict_idx][true_idx] += 1

    col_sum = np.sum(confusion_matrix, axis=0).reshape(-1)
    row_sum = np.sum(confusion_matrix, axis=1).reshape(-1)
    f1 = np.zeros(shape=[class_num], dtype=np.float32)

    for i in range(0, class_num):
        try:
            p = float(confusion_matrix[i][i]) / float(col_sum[i] + xDIRx[i])
            r = float(confusion_matrix[i][i]) / float(row_sum[i] + xDIRx[i])
            f1[i] = (2 * p * r / (p + r))
        except ZeroDivisionError:  # family never predicted or never present
            pass
    actual_class = 0
    total_f1 = 0.0
    for i in range(1, class_num):  # skip 'Other' (index 0)
        if f1[i] > 0.0:  # families with zero F1 (e.g. never predicted) are not counted
            actual_class += 1
            total_f1 += f1[i]
    try:
        macro_f1 = total_f1 / actual_class
    except ZeroDivisionError:
        macro_f1 = 0.0
    return macro_f1


class Eval(object):
    def __init__(self, config):
        self.device = config.device

    def evaluate(self, model, criterion, data_loader):
        predict_label = []
        true_label = []
        total_loss = 0.0
        with torch.no_grad():
            model.eval()
            for _, (data, label) in enumerate(data_loader):
                sent_feat = data[0].to(self.device)
                lex_feat = data[1].to(self.device)
                data = (sent_feat, lex_feat)
                label = label.to(self.device)

                scores = model(data)
                loss = criterion(scores, label)
                total_loss += loss.item() * scores.shape[0]

                scores, pred = torch.max(scores[:, 1:], dim=1)
                pred = pred + 1

                scores = scores.cpu().detach().numpy().reshape((-1, 1))
                pred = pred.cpu().detach().numpy().reshape((-1, 1))
                label = label.cpu().detach().numpy().reshape((-1, 1))

                # During prediction time, a relation is classified as Other
                # only if all actual classes have negative scores.
                # Otherwise, it is classified with the class which has the largest score.
                for i in range(pred.shape[0]):
                    if scores[i][0] < 0:
                        pred[i][0] = 0

                predict_label.append(pred)
                true_label.append(label)
        predict_label = np.concatenate(predict_label, axis=0).reshape(-1).astype(np.int64)
        true_label = np.concatenate(true_label, axis=0).reshape(-1).astype(np.int64)
        eval_loss = total_loss / predict_label.shape[0]

        f1 = semeval_scorer(predict_label, true_label)
        return f1, eval_loss, predict_label
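
As an illustrative example of the scorer (synthetic labels, not real predictions), using the relation2id mapping above, a prediction that gets the relation family right but the direction wrong still lowers the score:

import numpy as np
from evaluate import semeval_scorer

# synthetic labels: 1/2 = Cause-Effect(e1,e2)/(e2,e1),
# 7 = Entity-Destination(e1,e2), 0 = Other
true = np.array([1, 1, 7, 0])
pred = np.array([1, 2, 7, 0])   # second example: right family, wrong direction
print(semeval_scorer(pred, true))  # prints 0.75: Cause-Effect F1 = 0.5, Entity-Destination F1 = 1.0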



