Original paper: "Relation Classification via Convolutional Deep Neural Network" (Zeng et al., COLING 2014)
I. Overview
1. Why the idea was proposed
Before this paper, relation classification relied mainly on statistical machine-learning methods, whose performance depends on the quality of the extracted features.
Feature extraction in turn depends on the output of existing NLP systems (taggers, parsers, etc.), so errors made by those tools propagate into the relation classifier, leaving the whole approach heavily dependent on them.
The resulting chain of task dependencies also makes training complicated.
The question the paper asks: how can entity and position features be exploited to train relation classification end to end?
2. Core of the abstract
Current relation classification depends heavily on NLP preprocessing tools, which can accumulate and propagate errors.
The paper proposes a relation classification method that requires no complicated preprocessing.
Experimental results show the method is effective and achieves state-of-the-art performance.
3. Contributions
Proposes a CNN architecture that handles relation classification end to end:
- lexical-level features and sentence-level features;
- position features (PF) that encode each word's relative distance to the two target nouns (see the sketch after this list);
- fuses word-embedding information into a better context-extraction model;
- experiments on the SemEval-2010 Task 8 dataset, reaching the best result at the time.
4. Historical significance
II. Code walkthrough
1. Dataset
1.1 Raw dataset
train_file.txt (samples 1-8000)
1 "The system as described above has its greatest application in an arrayed <e1>configuration</e1> of antenna <e2>elements</e2>."
Component-Whole(e2,e1)
Comment: Not a collection: there is structure here, organisation.
2 "The <e1>child</e1> was carefully wrapped and bound into the <e2>cradle</e2> by means of a cord."
Other
Comment:
3 "The <e1>author</e1> of a keygen uses a <e2>disassembler</e2> to look at the raw assembly code."
Instrument-Agency(e2,e1)
Comment:
4 "A misty <e1>ridge</e1> uprises from the <e2>surge</e2>."
Other
Comment:
5 "The <e1>student</e1> <e2>association</e2> is the voice of the undergraduate student population of the State University of New York at Buffalo."
Member-Collection(e1,e2)
Comment:
6 "This is the sprawling <e1>complex</e1> that is Peru's largest <e2>producer</e2> of silver."
Other
Comment:
7 "The current view is that the chronic <e1>inflammation</e1> in the distal part of the stomach caused by Helicobacter pylori <e2>infection</e2> results in an increased acid production from the non-infected upper corpus region of the stomach."
Cause-Effect(e2,e1)
Comment:
8 "<e1>People</e1> have been moving back into <e2>downtown</e2>."
Entity-Destination(e1,e2)
Comment:
9 "The <e1>lawsonite</e1> was contained in a <e2>platinum crucible</e2> and the counter-weight was a plastic crucible with metal pieces."
Content-Container(e1,e2)
Comment: prototypical example
10 "The solute was placed inside a beaker and 5 mL of the <e1>solvent</e1> was pipetted into a 25 mL glass <e2>flask</e2> for each trial."
Entity-Destination(e1,e2)
Comment:
......
test_file.txt (samples 8001-10717)
8001 "The most common <e1>audits</e1> were about <e2>waste</e2> and recycling."
Message-Topic(e1,e2)
Comment: Assuming an audit = an audit document.
8002 "The <e1>company</e1> fabricates plastic <e2>chairs</e2>."
Product-Producer(e2,e1)
Comment: (a) is satisfied
8003 "The school <e1>master</e1> teaches the lesson with a <e2>stick</e2>."
Instrument-Agency(e2,e1)
Comment:
8004 "The suspect dumped the dead <e1>body</e1> into a local <e2>reservoir</e2>."
Entity-Destination(e1,e2)
Comment:
8005 "Avian <e1>influenza</e1> is an infectious disease of birds caused by type A strains of the influenza <e2>virus</e2>."
Cause-Effect(e2,e1)
Comment:
8006 "The <e1>ear</e1> of the African <e2>elephant</e2> is significantly larger--measuring 183 cm by 114 cm in the bush elephant."
Component-Whole(e1,e2)
Comment:
8007 "A child is told a <e1>lie</e1> for several years by their <e2>parents</e2> before he/she realizes that a Santa Claus does not exist."
Product-Producer(e1,e2)
Comment: (a) is satisfied; negation is outside
8008 "Skype, a free software, allows a <e1>hookup</e1> of multiple computer <e2>users</e2> to join in an online conference call without incurring any telephone costs."
Member-Collection(e2,e1)
Comment:
8009 "The disgusting scene was retaliation against her brother Philip who rents the <e1>room</e1> inside this apartment <e2>house</e2> on Lombard street."
Component-Whole(e1,e2)
Comment:
8010 "This <e1>thesis</e1> defines the <e2>clinical characteristics</e2> of amyloid disease."
Message-Topic(e1,e2)
Comment: may be we could leave clinical out of e2.
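Note on the raw format: in the original SemEval files each record occupies four lines (the numbered sentence, the relation label, a Comment line, and a blank line). The blank lines are omitted in the excerpt above, but the converter below relies on this 4-line stride.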
1.2 Processed data
class processor(object):
def __init__(self):
pass
def search_entity(self, sentence):
e1 = re.findall(r'<e1>(.*)</e1>', sentence)[0]
e2 = re.findall(r'<e2>(.*)</e2>', sentence)[0]
sentence = sentence.replace('<e1>' + e1 + '</e1>', ' <e1> ' + e1 + ' </e1> ', 1)
sentence = sentence.replace('<e2>' + e2 + '</e2>', ' <e2> ' + e2 + ' </e2> ', 1)
sentence = word_tokenize(sentence)
sentence = ' '.join(sentence)
sentence = sentence.replace('< e1 >', '<e1>')
sentence = sentence.replace('< e2 >', '<e2>')
sentence = sentence.replace('< /e1 >', '</e1>')
sentence = sentence.replace('< /e2 >', '</e2>')
sentence = sentence.split()
assert '<e1>' in sentence
assert '<e2>' in sentence
assert '</e1>' in sentence
assert '</e2>' in sentence
subj_start = subj_end = obj_start = obj_end = 0
pure_sentence = []
for i, word in enumerate(sentence):
if '<e1>' == word:
subj_start = len(pure_sentence)
continue
if '</e1>' == word:
subj_end = len(pure_sentence) - 1
continue
if '<e2>' == word:
obj_start = len(pure_sentence)
continue
if '</e2>' == word:
obj_end = len(pure_sentence) - 1
continue
pure_sentence.append(word)
return e1, e2, subj_start, subj_end, obj_start, obj_end, pure_sentence
def convert(self, path_src, path_des):
with open(path_src, 'r', encoding='utf-8') as fr:
data = fr.readlines()
with open(path_des, 'w', encoding='utf-8') as fw:
for i in range(0, len(data), 4):
id_s, sentence = data[i].strip().split('\t')
sentence = sentence[1:-1]
e1, e2, subj_start, subj_end, obj_start, obj_end, sentence = self.search_entity(sentence)
meta = dict(
id=id_s,
relation=data[i + 1].strip(),
head=e1,
tail=e2,
subj_start=subj_start,
subj_end=subj_end,
obj_start=obj_start,
obj_end=obj_end,
sentence=sentence,
comment=data[i + 2].strip()[8:]
)
json.dump(meta, fw, ensure_ascii=False)
fw.write('\n')
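A minimal usage sketch (assumes the imports at the top of dataset.py, where this class lives, and that NLTK's punkt tokenizer data is downloaded); the spans match record 2 of train.json below:

p = processor()
s = 'The <e1>child</e1> was carefully wrapped and bound into the <e2>cradle</e2> by means of a cord.'
e1, e2, s_start, s_end, o_start, o_end, tokens = p.search_entity(s)
print(e1, e2, (s_start, s_end), (o_start, o_end))  # child cradle (1, 1) (9, 9)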
train.json
{"id": "1", "relation": "Component-Whole(e2,e1)", "head": "configuration", "tail": "elements", "subj_start": 12, "subj_end": 12, "obj_start": 15, "obj_end": 15, "sentence": ["The", "system", "as", "described", "above", "has", "its", "greatest", "application", "in", "an", "arrayed", "configuration", "of", "antenna", "elements", "."], "comment": " Not a collection: there is structure here, organisation."}
{"id": "2", "relation": "Other", "head": "child", "tail": "cradle", "subj_start": 1, "subj_end": 1, "obj_start": 9, "obj_end": 9, "sentence": ["The", "child", "was", "carefully", "wrapped", "and", "bound", "into", "the", "cradle", "by", "means", "of", "a", "cord", "."], "comment": ""}
{"id": "3", "relation": "Instrument-Agency(e2,e1)", "head": "author", "tail": "disassembler", "subj_start": 1, "subj_end": 1, "obj_start": 7, "obj_end": 7, "sentence": ["The", "author", "of", "a", "keygen", "uses", "a", "disassembler", "to", "look", "at", "the", "raw", "assembly", "code", "."], "comment": ""}
{"id": "4", "relation": "Other", "head": "ridge", "tail": "surge", "subj_start": 2, "subj_end": 2, "obj_start": 6, "obj_end": 6, "sentence": ["A", "misty", "ridge", "uprises", "from", "the", "surge", "."], "comment": ""}
{"id": "5", "relation": "Member-Collection(e1,e2)", "head": "student", "tail": "association", "subj_start": 1, "subj_end": 1, "obj_start": 2, "obj_end": 2, "sentence": ["The", "student", "association", "is", "the", "voice", "of", "the", "undergraduate", "student", "population", "of", "the", "State", "University", "of", "New", "York", "at", "Buffalo", "."], "comment": ""}
{"id": "6", "relation": "Other", "head": "complex", "tail": "producer", "subj_start": 4, "subj_end": 4, "obj_start": 10, "obj_end": 10, "sentence": ["This", "is", "the", "sprawling", "complex", "that", "is", "Peru", "'s", "largest", "producer", "of", "silver", "."], "comment": ""}
{"id": "7", "relation": "Cause-Effect(e2,e1)", "head": "inflammation", "tail": "infection", "subj_start": 7, "subj_end": 7, "obj_start": 19, "obj_end": 19, "sentence": ["The", "current", "view", "is", "that", "the", "chronic", "inflammation", "in", "the", "distal", "part", "of", "the", "stomach", "caused", "by", "Helicobacter", "pylori", "infection", "results", "in", "an", "increased", "acid", "production", "from", "the", "non-infected", "upper", "corpus", "region", "of", "the", "stomach", "."], "comment": ""}
{"id": "8", "relation": "Entity-Destination(e1,e2)", "head": "People", "tail": "downtown", "subj_start": 0, "subj_end": 0, "obj_start": 6, "obj_end": 6, "sentence": ["People", "have", "been", "moving", "back", "into", "downtown", "."], "comment": ""}
{"id": "9", "relation": "Content-Container(e1,e2)", "head": "lawsonite", "tail": "platinum crucible", "subj_start": 1, "subj_end": 1, "obj_start": 6, "obj_end": 7, "sentence": ["The", "lawsonite", "was", "contained", "in", "a", "platinum", "crucible", "and", "the", "counter-weight", "was", "a", "plastic", "crucible", "with", "metal", "pieces", "."], "comment": " prototypical example"}
{"id": "10", "relation": "Entity-Destination(e1,e2)", "head": "solvent", "tail": "flask", "subj_start": 12, "subj_end": 12, "obj_start": 20, "obj_end": 20, "sentence": ["The", "solute", "was", "placed", "inside", "a", "beaker", "and", "5", "mL", "of", "the", "solvent", "was", "pipetted", "into", "a", "25", "mL", "glass", "flask", "for", "each", "trial", "."], "comment": ""}
......
test.json
{"id": "8001", "relation": "Message-Topic(e1,e2)", "head": "audits", "tail": "waste", "subj_start": 3, "subj_end": 3, "obj_start": 6, "obj_end": 6, "sentence": ["The", "most", "common", "audits", "were", "about", "waste", "and", "recycling", "."], "comment": " Assuming an audit = an audit document."}
{"id": "8002", "relation": "Product-Producer(e2,e1)", "head": "company", "tail": "chairs", "subj_start": 1, "subj_end": 1, "obj_start": 4, "obj_end": 4, "sentence": ["The", "company", "fabricates", "plastic", "chairs", "."], "comment": " (a) is satisfied"}
{"id": "8003", "relation": "Instrument-Agency(e2,e1)", "head": "master", "tail": "stick", "subj_start": 2, "subj_end": 2, "obj_start": 8, "obj_end": 8, "sentence": ["The", "school", "master", "teaches", "the", "lesson", "with", "a", "stick", "."], "comment": ""}
{"id": "8004", "relation": "Entity-Destination(e1,e2)", "head": "body", "tail": "reservoir", "subj_start": 5, "subj_end": 5, "obj_start": 9, "obj_end": 9, "sentence": ["The", "suspect", "dumped", "the", "dead", "body", "into", "a", "local", "reservoir", "."], "comment": ""}
{"id": "8005", "relation": "Cause-Effect(e2,e1)", "head": "influenza", "tail": "virus", "subj_start": 1, "subj_end": 1, "obj_start": 16, "obj_end": 16, "sentence": ["Avian", "influenza", "is", "an", "infectious", "disease", "of", "birds", "caused", "by", "type", "A", "strains", "of", "the", "influenza", "virus", "."], "comment": ""}
{"id": "8006", "relation": "Component-Whole(e1,e2)", "head": "ear", "tail": "elephant", "subj_start": 1, "subj_end": 1, "obj_start": 5, "obj_end": 5, "sentence": ["The", "ear", "of", "the", "African", "elephant", "is", "significantly", "larger", "--", "measuring", "183", "cm", "by", "114", "cm", "in", "the", "bush", "elephant", "."], "comment": ""}
{"id": "8007", "relation": "Product-Producer(e1,e2)", "head": "lie", "tail": "parents", "subj_start": 5, "subj_end": 5, "obj_start": 11, "obj_end": 11, "sentence": ["A", "child", "is", "told", "a", "lie", "for", "several", "years", "by", "their", "parents", "before", "he/she", "realizes", "that", "a", "Santa", "Claus", "does", "not", "exist", "."], "comment": " (a) is satisfied; negation is outside"}
{"id": "8008", "relation": "Member-Collection(e2,e1)", "head": "hookup", "tail": "users", "subj_start": 8, "subj_end": 8, "obj_start": 12, "obj_end": 12, "sentence": ["Skype", ",", "a", "free", "software", ",", "allows", "a", "hookup", "of", "multiple", "computer", "users", "to", "join", "in", "an", "online", "conference", "call", "without", "incurring", "any", "telephone", "costs", "."], "comment": ""}
{"id": "8009", "relation": "Component-Whole(e1,e2)", "head": "room", "tail": "house", "subj_start": 12, "subj_end": 12, "obj_start": 16, "obj_end": 16, "sentence": ["The", "disgusting", "scene", "was", "retaliation", "against", "her", "brother", "Philip", "who", "rents", "the", "room", "inside", "this", "apartment", "house", "on", "Lombard", "street", "."], "comment": ""}
{"id": "8010", "relation": "Message-Topic(e1,e2)", "head": "thesis", "tail": "clinical characteristics", "subj_start": 1, "subj_end": 1, "obj_start": 4, "obj_end": 5, "sentence": ["This", "thesis", "defines", "the", "clinical", "characteristics", "of", "amyloid", "disease", "."], "comment": " may be we could leave clinical out of e2."}
......
1.3 relation2id.txt
Other 0
Cause-Effect(e1,e2) 1
Cause-Effect(e2,e1) 2
Component-Whole(e1,e2) 3
Component-Whole(e2,e1) 4
Content-Container(e1,e2) 5
Content-Container(e2,e1) 6
Entity-Destination(e1,e2) 7
Entity-Destination(e2,e1) 8
Entity-Origin(e1,e2) 9
Entity-Origin(e2,e1) 10
Instrument-Agency(e1,e2) 11
Instrument-Agency(e2,e1) 12
Member-Collection(e1,e2) 13
Member-Collection(e2,e1) 14
Message-Topic(e1,e2) 15
Message-Topic(e2,e1) 16
Product-Producer(e1,e2) 17
Product-Producer(e2,e1) 18
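The 19 labels are 9 directed relations x 2 directions plus Other. With this id layout, a directed label collapses to its undirected relation via ceil(id / 2), which is exactly what semeval_scorer in evaluate.py below exploits:

import math
assert math.ceil(0 / 2) == 0                        # Other
assert math.ceil(1 / 2) == math.ceil(2 / 2) == 1    # Cause-Effect, both directions
assert math.ceil(17 / 2) == math.ceil(18 / 2) == 9  # Product-Producer, both directions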
2. Pre-trained word embeddings: static HLBL vectors (50-dimensional)
hlbl-embeddings-scaled.EMBEDDING_SIZE=50
*UNKNOWN* -0.166038776479 0.104395984608 0.163119732357 0.0899594154863 -0.0192271099805 -0.0417631572501 -0.0163376687927 0.0357616216019 0.0536077591673 0.0127688536503 -0.00284508433021 -0.0626207031228 -0.0379452734015 -0.103548297666 0.0381169119981 0.00199421074321 -0.0474636488659 -0.0127526851513 0.016404178535 -0.12759853361 -0.0292937037717 -0.0512566352549 0.0233097445983 0.0360505083995 0.00229317984472 -0.0771565284227 0.0071461584378 -0.051608090196 -0.0267547654304 0.0492994451068 -0.0531630844999 0.00787191810391 0.082280106873 0.066908641868 -0.0283930612982 0.216840166248 0.164923151267 0.00188498983723 0.0328679039324 -0.00175432516758 0.0614261774935 0.0987773071377 0.0548423375506 -0.0307057922059 0.053074241476 0.04982054279 -0.0572485864016 0.132236444766 -0.0379717035014 -0.120915939814
the -0.0841015569168 0.145263825738 0.116945121935 -0.0754618634155 0.17901499611 -0.000652852605208 -0.0713783879233 0.207273704502 0.060711721477 0.0366727701165 -0.0269791566731 -0.156993473526 -0.0393947453024 0.00749161628231 -0.332851634057 -0.1708430781 -0.275163605231 -0.266592614101 0.43349041466 -0.00779248211778 0.031101796379 -0.0257114150838 0.174856713352 -0.0543054233622 -0.0846669459476 -0.006234398456 0.00414488584462 0.119738648443 -0.0914876936952 -0.317381121871 -0.27471439742 0.234269597998 0.170305945138 -0.0282815073325 -0.10127814458 0.156451476203 0.154703520781 -0.0014827085612 0.164287521114 0.0328582913203 0.0356570354049 -0.190254406793 -0.112029936115 -0.198875312619 0.00102875631152 -0.00161517169984 -0.125210890327 0.196903181061 -0.112017915766 -0.00838804375065
. -0.0875932389444 -0.0586365253633 0.0729727126603 0.32072000431 0.0745620569276 -0.0494709138174 0.208708067552 -0.025035364294 -0.197531050237 0.177318202028 0.297077745222 -0.0256369072571 0.182364658364 0.189089099105 0.0589179494006 -0.0627276310572 0.0682898379459 0.241161712515 0.253510796291 -0.0325139691451 -0.0129081882483 -0.083367340352 0.0276167362372 -0.00757124183183 -0.0905801885623 0.305015208385 0.0755474920504 -0.00516459185438 -0.0412876867803 0.105047372601 -0.718674456034 0.184682477295 0.232732814491 0.0929975692214 0.0999329447708 -0.0968008990987 0.421525505372 -0.136460066398 -0.323294448817 0.118318915141 0.415411774103 -0.135770867168 0.0404792691614 0.264279769529 -0.133076243622 0.195087919022 -0.087589323012 0.0335223022065 -0.0365650611956 -0.0163760300203
, -0.023019838485 0.277215570968 0.241932261453 -0.105403438907 0.247316949736 0.0859618436243 -0.0130132156599 0.123988163629 -0.150741462418 0.129993766762 0.0766431623839 0.0547135456598 0.187342182554 0.176303102861 -0.121401723217 0.0458278230666 0.0339804870854 -0.0619606057248 0.0514787739809 0.00732501266557 0.0879996990484 -0.369288823679 0.235222707122 -0.0528783055204 0.0121891472663 -0.165169815904 -0.136829953355 -0.0750751223049 -0.0503433833321 0.0782539868365 -0.400940778018 -0.099745222007 -0.152448498545 -0.0815002789835 -0.010575616616 0.331604536668 -0.0124179474775 0.00173559407939 -0.230971231526 0.0162523457081 0.213848645598 0.184698023693 0.158368229826 0.0975422545404 -0.0307127563081 0.093420146492 -0.0377856184872 -0.0181716170654 0.43322993915 -0.113289957059
to 0.134693667961 0.392203653086 0.0346151199225 0.135354475458 0.0719918082372 0.118667933013 -0.0698386234679 -0.0139927084407 0.144452931939 0.0383223273458 -0.0491954394553 -0.126435975874 0.23979196724 -0.186550477314 0.0602616605691 -0.0875395769807 0.0788848675161 0.132691898026 0.155618778336 0.00680378469567 -0.126513561203 -0.436124771467 0.132675129426 -0.0946286638801 0.0986847070674 -0.354397304845 -0.196909463175 -0.0911408611189 0.134975690877 0.0625931974859 0.0108112360985 -0.107933544401 -0.166545488854 0.0137397678012 -0.0268394211932 -0.260328038765 0.0745185746772 0.020864049205 0.133485534344 -0.0479098207297 0.145382061477 -0.116284346216 0.0822848147919 -0.00621959258902 0.0135679910959 -0.0723116375013 -0.422793539068 0.144456402991 -0.119019192402 0.0659297394103
......
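Each line is a token followed by its 50 coordinates; the loader below skips any line whose field count is not word_dim + 1. A quick sanity check, assuming the file sits at the default path from config.py:

with open('embedding/hlbl-embeddings-scaled.EMBEDDING_SIZE=50.txt', encoding='utf-8') as f:
    fields = f.readline().split()
print(fields[0], len(fields) - 1)  # *UNKNOWN* 50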
3. config.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import argparse
import torch
import os
import random
import json
import numpy as np
class Config(object):
def __init__(self):
# get init config
args = self.__get_config()
for key in args.__dict__:
setattr(self, key, args.__dict__[key])
# select device
self.device = None
if self.cuda >= 0 and torch.cuda.is_available():
self.device = torch.device('cuda:{}'.format(self.cuda))
else:
self.device = torch.device('cpu')
# determine the model name and model dir
if self.model_name is None:
self.model_name = 'CNN2'
self.model_dir = os.path.join(self.output_dir, self.model_name)
if not os.path.exists(self.model_dir):
os.makedirs(self.model_dir)
# backup data
self.__config_backup(args)
# set the random seed
self.__set_seed(self.seed)
def __get_config(self):
parser = argparse.ArgumentParser()
parser.description = 'config for models'
# several key selective parameters
parser.add_argument('--data_dir', type=str,
default='./data',
help='dir to load data')
parser.add_argument('--output_dir', type=str,
default='./output',
help='dir to save output')
# word embedding
parser.add_argument('--embedding_path', type=str,
default='./embedding/hlbl-embeddings-scaled.EMBEDDING_SIZE=50.txt',
help='pre_trained word embedding')
parser.add_argument('--word_dim', type=int,
default=50,
help='dimension of word embedding')
# train settings
parser.add_argument('--model_name', type=str,
default=None,
help='model name')
parser.add_argument('--mode', type=int,
default=1,
choices=[0, 1],
help='running mode: 1 for training; otherwise testing')
parser.add_argument('--seed', type=int,
default=666,
help='random seed')
        parser.add_argument('--cuda', type=int,
                            default=0,
                            help='gpu device id; -1 selects cpu')
        parser.add_argument('--epoch', type=int,
                            default=100,
                            help='max epochs during training')
# hyper parameters
        parser.add_argument('--dropout', type=float,
                            default=0.5,
                            help='dropout probability')
parser.add_argument('--batch_size', type=int,
default=128,
help='batch size')
parser.add_argument('--lr', type=float,
default=0.001,
help='learning rate')
parser.add_argument('--max_len', type=int,
default=96,
help='max length of sentence')
parser.add_argument('--pos_dis', type=int, default=20,
help='max distance of position embedding')
parser.add_argument('--pos_dim', type=int,
default=5,
help='dimension of position embedding')
parser.add_argument('--hidden_size', type=int, default=100,
help='the size of linear layer between convolution and classification')
# hyper parameters for cnn
parser.add_argument('--filter_num', type=int, default=200,
help='the number of filters in convolution')
parser.add_argument('--window', type=int, default=3,
help='the size of window in convolution')
parser.add_argument('--L2_decay', type=float, default=0.0001,
help='L2 weight decay')
args = parser.parse_args()
return args
def __set_seed(self, seed=1234):
os.environ['PYTHONHASHSEED'] = '{}'.format(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed) # set seed for cpu
torch.cuda.manual_seed(seed) # set seed for current gpu
torch.cuda.manual_seed_all(seed) # set seed for all gpu
def __config_backup(self, args):
config_backup_path = os.path.join(self.model_dir, 'config.json')
with open(config_backup_path, 'w', encoding='utf-8') as fw:
json.dump(vars(args), fw, ensure_ascii=False)
def print_config(self):
for key in self.__dict__:
print(key, end=' = ')
print(self.__dict__[key])
if __name__ == '__main__':
config = Config()
config.print_config()
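All options are ordinary argparse flags, so the scripts below can be driven from the command line, e.g. python train_or_test.py --mode 1 --cuda -1 --epoch 100 to train on CPU (flags exactly as defined by the parser above).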
4. dataset.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import os
import torch
import numpy as np
import json
import re
from nltk.tokenize import word_tokenize
from torch.utils.data import Dataset, DataLoader
class WordEmbeddingLoader(object):
"""
A loader for pre-trained word embedding
"""
def __init__(self, config):
self.path_word = config.embedding_path # path of pre-trained word embedding
self.word_dim = config.word_dim # dimension of word embedding
def trim_from_pre_embedding(self, vocab):
word2id = dict()
word_vec = {}
trim_word_vec = list()
with open(self.path_word, 'r', encoding='utf-8') as fr:
for line in fr:
line = line.strip().split()
if len(line) != self.word_dim + 1:
continue
word_vec[line[0]] = np.asarray(line[1:], dtype=np.float32)
for word in vocab:
word2id[word] = len(word2id)
if (word in word_vec):
trim_word_vec.append(word_vec[word])
else:
trim_word_vec.append(np.random.uniform(-1, 1, self.word_dim))
        # add the special tokens: *UNKNOWN* for out-of-vocabulary words, PAD for padding
        if "*UNKNOWN*" not in word2id:
            word2id['*UNKNOWN*'] = len(word2id)
            unk_emb = np.random.uniform(-1, 1, self.word_dim)
            trim_word_vec.append(unk_emb)
        if "PAD" not in word2id:
            word2id['PAD'] = len(word2id)
            pad_emb = np.zeros(self.word_dim)
            trim_word_vec.append(pad_emb)  # bug fix: was unk_emb, which left PAD with a random vector
trim_word_vec = np.array(trim_word_vec)
trim_word_vec = trim_word_vec.astype(np.float32).reshape(-1, self.word_dim)
return word2id, torch.from_numpy(trim_word_vec)
def load_embedding(self):
word2id = dict() # word to wordID
word_vec = list() # wordID to word embedding
word2id['PAD'] = len(word2id) # PAD character
with open(self.path_word, 'r', encoding='utf-8') as fr:
for line in fr:
line = line.strip().split()
if len(line) != self.word_dim + 1:
continue
word2id[line[0]] = len(word2id)
word_vec.append(np.asarray(line[1:], dtype=np.float32))
if ("*UNKNOWN*" not in word2id):
word2id['*UNKNOWN*'] = len(word2id)
unk_emb = np.random.uniform(-1, 1, self.word_dim)
word_vec.append(unk_emb)
        pad_emb = np.zeros([1, self.word_dim], dtype=np.float32)  # 'PAD' (id 0) is initialized as zeros
word_vec = np.concatenate((pad_emb, word_vec), axis=0)
word_vec = word_vec.astype(np.float32).reshape(-1, self.word_dim)
word_vec = torch.from_numpy(word_vec)
return word2id, word_vec
class RelationLoader(object):
def __init__(self, config):
self.data_dir = config.data_dir
def __load_relation(self):
relation_file = os.path.join(self.data_dir, 'relation2id.txt')
rel2id = {}
id2rel = {}
with open(relation_file, 'r', encoding='utf-8') as fr:
for line in fr:
relation, id_s = line.strip().split()
id_d = int(id_s)
rel2id[relation] = id_d
id2rel[id_d] = relation
return rel2id, id2rel, len(rel2id)
def get_relation(self):
return self.__load_relation()
class SemEvalDataset(Dataset):
def __init__(self, filename, rel2id, word2id, config):
self.filename = filename
self.rel2id = rel2id
self.word2id = word2id
self.max_len = config.max_len
self.pos_dis = config.pos_dis
self.data_dir = config.data_dir
self.dataset, self.label = self.__load_data()
def __get_pos_index(self, x):
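        # clip the relative distance to [-pos_dis, pos_dis], then shift it into
        # [1, 2*pos_dis+1]; 0 and 2*pos_dis+2 are the two out-of-range buckets,
        # matching num_embeddings=2*pos_dis+3 in model.py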
if x < -self.pos_dis:
return 0
        if -self.pos_dis <= x <= self.pos_dis:
return x + self.pos_dis + 1
if x > self.pos_dis:
return 2 * self.pos_dis + 2
def __get_relative_pos(self, x, entity_pos):
if x < entity_pos[0]:
return self.__get_pos_index(x - entity_pos[0])
elif x > entity_pos[1]:
return self.__get_pos_index(x - entity_pos[1])
else:
return self.__get_pos_index(0)
def __symbolize_sentence(self, e1_pos, e2_pos, sentence):
"""
Args:
e1_pos (tuple) span of e1
e2_pos (tuple) span of e2
sentence (list)
"""
mask = [1] * len(sentence)
if e1_pos[0] < e2_pos[0]:
for i in range(e1_pos[0], e2_pos[1] + 1):
mask[i] = 2
for i in range(e2_pos[1] + 1, len(sentence)):
mask[i] = 3
else:
for i in range(e2_pos[0], e1_pos[1] + 1):
mask[i] = 2
for i in range(e1_pos[1] + 1, len(sentence)):
mask[i] = 3
words = []
pos1 = []
pos2 = []
length = min(self.max_len, len(sentence))
mask = mask[:length]
for i in range(length):
words.append(self.word2id.get(sentence[i], self.word2id['*UNKNOWN*']))
pos1.append(self.__get_relative_pos(i, e1_pos))
pos2.append(self.__get_relative_pos(i, e2_pos))
if length < self.max_len:
for i in range(length, self.max_len):
mask.append(0) # 'PAD' mask is zero
words.append(self.word2id['PAD'])
pos1.append(self.__get_relative_pos(i, e1_pos))
pos2.append(self.__get_relative_pos(i, e2_pos))
unit = np.asarray([words, pos1, pos2, mask], dtype=np.int64)
unit = np.reshape(unit, newshape=(1, 4, self.max_len))
return unit
def _lexical_feature(self, e1_idx, e2_idx, sent):
def _entity_context(e_idx, sent):
            ''' return [w(e), w(e-1), w(e+1)]; at a sentence boundary the
                missing neighbour falls back to w(e) itself
            '''
context = []
context.append(sent[e_idx])
if e_idx >= 1:
context.append(sent[e_idx - 1])
else:
context.append(sent[e_idx])
if e_idx < len(sent) - 1:
context.append(sent[e_idx + 1])
else:
context.append(sent[e_idx])
return context
context1 = _entity_context(e1_idx[0], sent)
context2 = _entity_context(e2_idx[0], sent)
# ignore WordNet hypernyms in paper
lexical = context1 + context2
lexical_ids = [self.word2id.get(word, self.word2id['*UNKNOWN*']) for word in lexical]
lexical_ids = np.asarray(lexical_ids, dtype=np.int64)
return np.reshape(lexical_ids, newshape=(1, 6))
def __load_data(self):
path_data_file = os.path.join(self.data_dir, self.filename)
data = []
labels = []
with open(path_data_file, 'r', encoding='utf-8') as fr:
for line in fr:
line = json.loads(line.strip())
label = line['relation']
sentence = line['sentence']
e1_pos = (line['subj_start'], line['subj_end'])
e2_pos = (line['obj_start'], line['obj_end'])
label_idx = self.rel2id[label]
one_sentence = self.__symbolize_sentence(e1_pos, e2_pos, sentence)
lexical = self._lexical_feature(e1_pos, e2_pos, sentence)
temp = (one_sentence, lexical)
data.append(temp)
# data.append(one_sentence)
labels.append(label_idx)
return data, labels
def __getitem__(self, index):
data = self.dataset[index]
label = self.label[index]
return data, label
def __len__(self):
return len(self.label)
class SemEvalDataLoader(object):
def __init__(self, rel2id, word2id, config):
self.rel2id = rel2id
self.word2id = word2id
self.config = config
def __collate_fn(self, batch):
data, label = zip(*batch) # unzip the batch data
data = list(data)
label = list(label)
sentence_feat = torch.from_numpy(np.concatenate([x[0] for x in data], axis=0))
lexical_feat = torch.from_numpy(np.concatenate([x[1] for x in data], axis=0))
label = torch.from_numpy(np.asarray(label, dtype=np.int64))
return (sentence_feat, lexical_feat), label
def __get_data(self, filename, shuffle=False):
        dataset = SemEvalDataset(filename, self.rel2id, self.word2id, self.config)
loader = DataLoader(
dataset=dataset,
batch_size=self.config.batch_size,
shuffle=shuffle,
num_workers=2,
collate_fn=self.__collate_fn
)
return loader
def get_train(self):
return self.__get_data('train.json', shuffle=True)
def get_dev(self):
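        # the corpus has no separate dev split, so the test set doubles as dev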
return self.__get_data('test.json', shuffle=False)
def get_test(self):
return self.__get_data('test.json', shuffle=False)
class processor(object):
def __init__(self):
pass
def search_entity(self, sentence):
e1 = re.findall(r'<e1>(.*)</e1>', sentence)[0]
e2 = re.findall(r'<e2>(.*)</e2>', sentence)[0]
sentence = sentence.replace('<e1>' + e1 + '</e1>', ' <e1> ' + e1 + ' </e1> ', 1)
sentence = sentence.replace('<e2>' + e2 + '</e2>', ' <e2> ' + e2 + ' </e2> ', 1)
sentence = word_tokenize(sentence)
sentence = ' '.join(sentence)
sentence = sentence.replace('< e1 >', '<e1>')
sentence = sentence.replace('< e2 >', '<e2>')
sentence = sentence.replace('< /e1 >', '</e1>')
sentence = sentence.replace('< /e2 >', '</e2>')
sentence = sentence.split()
assert '<e1>' in sentence
assert '<e2>' in sentence
assert '</e1>' in sentence
assert '</e2>' in sentence
subj_start = subj_end = obj_start = obj_end = 0
pure_sentence = []
for i, word in enumerate(sentence):
if '<e1>' == word:
subj_start = len(pure_sentence)
continue
if '</e1>' == word:
subj_end = len(pure_sentence) - 1
continue
if '<e2>' == word:
obj_start = len(pure_sentence)
continue
if '</e2>' == word:
obj_end = len(pure_sentence) - 1
continue
pure_sentence.append(word)
return e1, e2, subj_start, subj_end, obj_start, obj_end, pure_sentence
def convert(self, path_src, path_des):
with open(path_src, 'r', encoding='utf-8') as fr:
data = fr.readlines()
with open(path_des, 'w', encoding='utf-8') as fw:
for i in range(0, len(data), 4):
id_s, sentence = data[i].strip().split('\t')
sentence = sentence[1:-1]
e1, e2, subj_start, subj_end, obj_start, obj_end, sentence = self.search_entity(sentence)
meta = dict(
id=id_s,
relation=data[i + 1].strip(),
head=e1,
tail=e2,
subj_start=subj_start,
subj_end=subj_end,
obj_start=obj_start,
obj_end=obj_end,
sentence=sentence,
comment=data[i + 2].strip()[8:]
)
json.dump(meta, fw, ensure_ascii=False)
fw.write('\n')
class VocabGenerator(object):
def __init__(self, train_path, test_path):
self.train_path = train_path
self.test_path = test_path
def get_vocab(self):
vocab = {}
with open(self.train_path, 'r', encoding='utf-8') as fr:
for line in fr:
line = json.loads(line.strip())
sentence = line['sentence']
for word in sentence:
vocab[word] = 1
with open(self.test_path, 'r', encoding='utf-8') as fr:
for line in fr:
line = json.loads(line.strip())
sentence = line['sentence']
for word in sentence:
vocab[word] = 1
return vocab.keys()
if __name__ == '__main__':
path_train = 'data/train_file.txt'
path_test = 'data/test_file.txt'
processor1 = processor()
processor1.convert(path_train, 'data/train.json')
processor1.convert(path_test, 'data/test.json')
vocab = VocabGenerator('data/train.json', 'data/test.json').get_vocab()
# from config import Config
# config = Config()
# word2id, word_vec = WordEmbeddingLoader(config).load_embedding()
# rel2id, id2rel, class_num = RelationLoader(config).get_relation()
# loader = SemEvalDataLoader(rel2id, word2id, config)
# test_loader = loader.get_train()
#
# min_v, max_v = float('inf'), -float('inf')
# for step, (data, label) in enumerate(test_loader):
# # print(type(data), data.shape)
# # print(type(label), label.shape)
# # break
# pos1 = data[:, 1, :].view(-1, config.max_len)
# pos2 = data[:, 2, :].view(-1, config.max_len)
# mask = data[:, 3, :].view(-1, config.max_len)
# min_v = min(min_v, torch.min(pos1).item())
# max_v = max(max_v, torch.max(pos1).item())
# min_v = min(min_v, torch.min(pos2).item())
# max_v = max(max_v, torch.max(pos2).item())
# print(min_v, max_v)
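To see what a single processed sample looks like, a minimal sketch (assumes data/train.json, data/relation2id.txt and the embedding file exist at the default paths from config.py; SemEvalDataset is the class defined above):

from config import Config

cfg = Config()
word2id, word_vec = WordEmbeddingLoader(cfg).load_embedding()
rel2id, id2rel, class_num = RelationLoader(cfg).get_relation()
ds = SemEvalDataset('train.json', rel2id, word2id, cfg)
(sent, lex), y = ds[0]  # sent rows: token ids, pos1, pos2, mask
print(sent.shape, lex.shape, y)  # (1, 4, 96) (1, 6) 4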
5. model.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
class CNN(nn.Module):
def __init__(self, word_vec, class_num, config):
super().__init__()
self.word_vec = word_vec
self.class_num = class_num
# hyper parameters and others
self.max_len = config.max_len
self.word_dim = config.word_dim
self.pos_dim = config.pos_dim
self.pos_dis = config.pos_dis
self.dropout_value = config.dropout
self.filter_num = config.filter_num
self.window = config.window
self.hidden_size = config.hidden_size
self.dim = self.word_dim + 2 * self.pos_dim
# net structures and operations
self.word_embedding = nn.Embedding.from_pretrained(embeddings=self.word_vec, freeze=False, )
self.pos1_embedding = nn.Embedding(num_embeddings=2 * self.pos_dis + 3, embedding_dim=self.pos_dim)
self.pos2_embedding = nn.Embedding(num_embeddings=2 * self.pos_dis + 3, embedding_dim=self.pos_dim)
self.conv = nn.Conv2d(
in_channels=1,
out_channels=self.filter_num,
kernel_size=(self.window, self.dim),
stride=(1, 1),
bias=False,
padding=(1, 0), # same padding
padding_mode='zeros'
)
self.maxpool = nn.MaxPool2d((self.max_len, 1))
self.tanh = nn.Tanh()
self.dropout = nn.Dropout(self.dropout_value)
self.linear = nn.Linear(in_features=self.filter_num, out_features=self.hidden_size, bias=False)
self.dense = nn.Linear(in_features=self.hidden_size + 6 * self.word_dim, out_features=self.class_num, bias=False)
# initialize weight
init.xavier_normal_(self.pos1_embedding.weight)
init.xavier_normal_(self.pos2_embedding.weight)
init.xavier_normal_(self.conv.weight)
# init.constant_(self.conv.bias, 0.)
init.xavier_normal_(self.linear.weight)
# init.constant_(self.linear.bias, 0.)
init.xavier_normal_(self.dense.weight)
# init.constant_(self.dense.bias, 0.)
def encoder_layer(self, token, pos1, pos2):
word_emb = self.word_embedding(token) # B*L*word_dim
pos1_emb = self.pos1_embedding(pos1) # B*L*pos_dim
pos2_emb = self.pos2_embedding(pos2) # B*L*pos_dim
emb = torch.cat(tensors=[word_emb, pos1_emb, pos2_emb], dim=-1)
return emb # B*L*D, D=word_dim+2*pos_dim
def conv_layer(self, emb, mask):
emb = emb.unsqueeze(dim=1) # B*1*L*D
conv = self.conv(emb) # B*C*L*1
# mask, remove the effect of 'PAD'
conv = conv.view(-1, self.filter_num, self.max_len) # B*C*L
mask = mask.unsqueeze(dim=1) # B*1*L
mask = mask.expand(-1, self.filter_num, -1) # B*C*L
conv = conv.masked_fill_(mask.eq(0), float('-inf')) # B*C*L
conv = conv.unsqueeze(dim=-1) # B*C*L*1
return conv
def single_maxpool_layer(self, conv):
pool = self.maxpool(conv) # B*C*1*1
pool = pool.view(-1, self.filter_num) # B*C
return pool
def forward(self, data):
token = data[0][:, 0, :].view(-1, self.max_len)
pos1 = data[0][:, 1, :].view(-1, self.max_len)
pos2 = data[0][:, 2, :].view(-1, self.max_len)
mask = data[0][:, 3, :].view(-1, self.max_len)
lexical = data[1].view(-1, 6)
lexical_emb = self.word_embedding(lexical)
lexical_emb = lexical_emb.view(-1, self.word_dim * 6)
emb = self.encoder_layer(token, pos1, pos2)
emb = self.dropout(emb)
conv = self.conv_layer(emb, mask)
pool = self.single_maxpool_layer(conv)
sentence_feature = self.linear(pool)
sentence_feature = self.tanh(sentence_feature)
sentence_feature = self.dropout(sentence_feature)
features = torch.cat((lexical_emb, sentence_feature), 1)
logits = self.dense(features)
return logits
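A quick shape check of the whole network with fake inputs (run at the bottom of model.py, or after from model import CNN; the hyper-parameters copy config.py's defaults, while the vocabulary and batch are made up):

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(max_len=96, word_dim=50, pos_dim=5, pos_dis=20, dropout=0.5,
                      filter_num=200, window=3, hidden_size=100)
fake_vocab = torch.randn(100, 50)                    # 100 random word vectors
net = CNN(word_vec=fake_vocab, class_num=19, config=cfg)
sent = torch.randint(1, 100, (4, 4, 96))             # B=4, rows = (token, pos1, pos2, mask)
sent[:, 1:3, :] = torch.randint(0, 2 * 20 + 3, (4, 2, 96))  # keep position ids in range
lex = torch.randint(0, 100, (4, 6))
print(net((sent, lex)).shape)                        # torch.Size([4, 19])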
6. train_or_test.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import os
import torch
import torch.nn as nn
import torch.optim as optim
from config import Config
from dataset import WordEmbeddingLoader, RelationLoader, SemEvalDataLoader, VocabGenerator
from model import CNN
from evaluate import Eval
def print_result(predict_label, id2rel, start_idx=8001):
    os.makedirs('script', exist_ok=True)  # make sure the output dir exists
    with open('script/predicted_result.txt', 'w', encoding='utf-8') as fw:
for i in range(0, predict_label.shape[0]):
fw.write('{}\t{}\n'.format(start_idx + i, id2rel[int(predict_label[i])]))
def train(model, criterion, loader, config):
train_loader, dev_loader, _ = loader
# weight_decay_list = (param for name, param in model.named_parameters() if name[-4:] != 'bias' and "bn" not in name)
# no_decay_list = (param for name, param in model.named_parameters() if name[-4:] == 'bias' or "bn" in name)
# parameters = [{'params': weight_decay_list},
# {'params': no_decay_list, 'weight_decay': 0.}]
optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.L2_decay)
print(model)
    print('training model parameters:')
for name, param in model.named_parameters():
if param.requires_grad:
print('%s : %s' % (name, str(param.data.shape)))
print('--------------------------------------')
print('start to train the model ...')
eval_tool = Eval(config)
    best_f1 = -float('inf')  # best dev macro-F1 seen so far
for epoch in range(1, config.epoch + 1):
for step, (data, label) in enumerate(train_loader):
model.train()
sent_feat = data[0].to(config.device)
lex_feat = data[1].to(config.device)
data = (sent_feat, lex_feat)
# data = data.to(config.device)
label = label.to(config.device)
optimizer.zero_grad()
logits = model(data)
loss = criterion(logits, label)
loss.backward()
optimizer.step()
_, train_loss, _ = eval_tool.evaluate(model, criterion, train_loader)
f1, dev_loss, _ = eval_tool.evaluate(model, criterion, dev_loader)
        print('[%03d] train_loss: %.3f | dev_loss: %.3f | macro f1 on dev: %.4f'
              % (epoch, train_loss, dev_loss, f1), end=' ')
        if f1 > best_f1:
            best_f1 = f1
torch.save(model.state_dict(), os.path.join(config.model_dir, 'model.pkl'))
print('>>> save models!')
else:
print()
def test(model, criterion, loader, config):
print('--------------------------------------')
print('start test ...')
_, _, test_loader = loader
model.load_state_dict(torch.load(os.path.join(config.model_dir, 'model.pkl')))
eval_tool = Eval(config)
f1, test_loss, predict_label = eval_tool.evaluate(model, criterion, test_loader)
    print('test_loss: %.3f | macro f1 on test: %.4f' % (test_loss, f1))
return predict_label
if __name__ == '__main__':
config = Config()
print('--------------------------------------')
print('some config:')
config.print_config()
print('--------------------------------------')
print('start to load data ...')
    vocab = VocabGenerator('data/train.json', 'data/test.json').get_vocab()  # only needed for trim_from_pre_embedding; unused with load_embedding
word2id, word_vec = WordEmbeddingLoader(config).load_embedding()
rel2id, id2rel, class_num = RelationLoader(config).get_relation()
loader = SemEvalDataLoader(rel2id, word2id, config)
train_loader, dev_loader = None, None
if config.mode == 1: # train mode
train_loader = loader.get_train()
dev_loader = loader.get_dev()
test_loader = loader.get_test()
loader = [train_loader, dev_loader, test_loader]
print('finish!')
print('--------------------------------------')
model = CNN(word_vec=word_vec, class_num=class_num, config=config)
model = model.to(config.device)
criterion = nn.CrossEntropyLoss()
if config.mode == 1: # train mode
train(model, criterion, loader, config)
predict_label = test(model, criterion, loader, config)
print_result(predict_label, id2rel)
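With mode=1 the script trains, keeps the checkpoint with the best dev F1 under output/CNN2/model.pkl, then reloads it for the final test pass; with mode=0 it skips training and only evaluates an existing checkpoint.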
7. evaluate.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import numpy as np
import torch
def semeval_scorer(predict_label, true_label, class_num=10):
import math
assert true_label.shape[0] == predict_label.shape[0]
confusion_matrix = np.zeros(shape=[class_num, class_num], dtype=np.float32)
xDIRx = np.zeros(shape=[class_num], dtype=np.float32)
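    # class_num=10 because the 19 directed labels collapse to 9 relations +
    # 'Other' via ceil(label / 2); xDIRx[i] counts predictions with the right
    # relation but the wrong direction, which the official scorer treats as
    # both a false positive and a false negative for that relation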
for i in range(true_label.shape[0]):
true_idx = math.ceil(true_label[i]/2)
predict_idx = math.ceil(predict_label[i]/2)
if true_label[i] == predict_label[i]:
confusion_matrix[predict_idx][true_idx] += 1
else:
if true_idx == predict_idx:
xDIRx[predict_idx] += 1
else:
confusion_matrix[predict_idx][true_idx] += 1
col_sum = np.sum(confusion_matrix, axis=0).reshape(-1)
row_sum = np.sum(confusion_matrix, axis=1).reshape(-1)
f1 = np.zeros(shape=[class_num], dtype=np.float32)
    for i in range(0, class_num):  # per-class F1; 'Other' (index 0) is excluded from the average below
try:
p = float(confusion_matrix[i][i]) / float(col_sum[i] + xDIRx[i])
r = float(confusion_matrix[i][i]) / float(row_sum[i] + xDIRx[i])
f1[i] = (2 * p * r / (p + r))
        except ZeroDivisionError:
            pass
actual_class = 0
total_f1 = 0.0
for i in range(1, class_num):
        if f1[i] > 0.0:  # skip classes with zero F1 (e.g. never predicted)
actual_class += 1
total_f1 += f1[i]
try:
macro_f1 = total_f1 / actual_class
    except ZeroDivisionError:
macro_f1 = 0.0
return macro_f1
class Eval(object):
def __init__(self, config):
self.device = config.device
def evaluate(self, model, criterion, data_loader):
predict_label = []
true_label = []
total_loss = 0.0
with torch.no_grad():
model.eval()
for _, (data, label) in enumerate(data_loader):
sent_feat = data[0].to(self.device)
lex_feat = data[1].to(self.device)
data = (sent_feat, lex_feat)
label = label.to(self.device)
scores = model(data)
loss = criterion(scores, label)
total_loss += loss.item() * scores.shape[0]
scores, pred = torch.max(scores[:, 1:], dim=1)
pred = pred + 1
scores = scores.cpu().detach().numpy().reshape((-1, 1))
pred = pred.cpu().detach().numpy().reshape((-1, 1))
label = label.cpu().detach().numpy().reshape((-1, 1))
# During prediction time, a relation is classified as Other
# only if all actual classes have negative scores.
# Otherwise, it is classified with the class which has the largest score.
for i in range(pred.shape[0]):
if scores[i][0] < 0:
pred[i][0] = 0
predict_label.append(pred)
true_label.append(label)
predict_label = np.concatenate(predict_label, axis=0).reshape(-1).astype(np.int64)
true_label = np.concatenate(true_label, axis=0).reshape(-1).astype(np.int64)
eval_loss = total_loss / predict_label.shape[0]
f1 = semeval_scorer(predict_label, true_label)
return f1, eval_loss, predict_label
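A tiny sanity check of the scorer on made-up labels (run after from evaluate import semeval_scorer): one correct Cause-Effect, one Cause-Effect with the direction flipped, one correct Other, one correct Component-Whole:

import numpy as np

pred = np.array([1, 2, 0, 3])
gold = np.array([1, 1, 0, 3])
print(semeval_scorer(pred, gold))  # 0.75: F1=0.5 for Cause-Effect, 1.0 for Component-Whole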