Original paper: "Attention-based bidirectional long short-term memory networks for relation classification" (Zhou et al., ACL 2016)
一、Overview
1、Motivation
Most traditional approaches rely on existing lexical resources (e.g., WordNet), NLP systems, or hand-crafted features. This can increase computational complexity; feature engineering itself costs a great deal of time and effort, and the quality of the extracted features strongly affects the results.
The paper proposes the Att-BLSTM network structure to solve relation classification end to end.
Starting from this observation, the paper builds a bidirectional LSTM with an attention mechanism for relation extraction. The attention mechanism automatically discovers the words that are decisive for classification, so the model can capture the most important semantic information in each sentence without relying on any external knowledge or NLP systems.
2、Significance
By neatly adding an attention mechanism to a bidirectional LSTM for relation extraction, the paper avoids the complex feature engineering of traditional pipelines, greatly simplifies the experimental process while still achieving strong results, and offers a practical, reusable recipe for related research.
The paper is very clearly organized and stays tightly focused on its motivation; the idea is simple and the model is easy to grasp, yet the results are excellent.
3、Key points of the abstract
- Existing relation classification relies on features extracted with NLP tools;
- The paper proposes Att-BLSTM, a relation classification method that needs no complex preprocessing;
- Experiments show the method is effective and reaches state-of-the-art performance.
二、The Attention-BiLSTM model
1、Model structure
The Att-BLSTM network is built on word embeddings; entity position indicators are added to the input, and the attention layer lets the model dynamically weight the words that matter most for relation classification. A minimal tensor-shape walkthrough of the five components is given after the list below.
As shown in Figure 1, the model proposed in this paper contains five components:
- Input layer: input sentence to this model;
- Embedding layer: map each word into a low dimension vector;
- LSTM layer: utilize BLSTM to get high level features from step (2);
- Attention layer: produce a weight vector, and merge word-level features from each time step into a sentence-level feature vector, by multiplying the weight vector;
- Output layer: the sentence-level feature vector is finally used for relation classification.
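To make the data flow concrete, here is a minimal tensor-shape walkthrough of the five components (a sketch, not the paper's code; the dimensions are illustrative and follow the defaults in config.py below):

import torch
import torch.nn as nn

B, L, D, H, C = 2, 100, 100, 100, 19                 # batch, max_len, word_dim, hidden_size, class_num
tokens = torch.randint(0, 1000, (B, L))              # (1) input sentence as word ids
emb = nn.Embedding(1000, D)(tokens)                  # (2) embedding layer -> (B, L, D)
out, _ = nn.LSTM(D, H, batch_first=True, bidirectional=True)(emb)
h = out.view(B, L, 2, H).sum(dim=2)                  # (3) BiLSTM, element-wise sum of the two directions -> (B, L, H)
w = torch.randn(1, H, 1).expand(B, H, 1)             # attention weight vector, shared across the batch
alpha = torch.softmax(torch.bmm(torch.tanh(h), w), dim=1)            # (4) word-level attention weights -> (B, L, 1)
rep = torch.tanh(torch.bmm(h.transpose(1, 2), alpha).squeeze(-1))    # weighted sum -> sentence vector (B, H)
logits = nn.Linear(H, C)(rep)                        # (5) output layer -> (B, C)
print(logits.shape)                                  # torch.Size([2, 19])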
2、How attention works
The attention mechanism lets the model assign a different weight to every part of the input X and pull out the most critical information, so that the model makes more accurate decisions without adding significant computation or memory overhead.
Depending on the region over which attention is computed, it can be divided into the following types (the specific formulation used by Att-BLSTM is given after the list):
- Soft attention / global attention: the most common form; a weight (probability) is computed for every key, so every key contributes. Because the computation is global, it is also called global attention.
- Hard attention: directly selects one key and ignores the rest, i.e. that key gets probability 1 and all others get 0. This requires exact alignment in a single step, and a wrong alignment hurts badly; moreover, because it is not differentiable, it usually has to be trained with reinforcement learning.
- Local attention: a compromise between the two, computed over a window. Hard attention first locates a position, a window is taken around that position, and soft attention is computed inside that window.
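For reference, the soft attention used in Att-BLSTM takes the form given in the original paper: with $H = [h_1, \dots, h_T]$ the matrix of BiLSTM output vectors and $w$ a trained parameter vector,

$$M = \tanh(H), \qquad \alpha = \mathrm{softmax}(w^{\top} M), \qquad r = H\alpha^{\top}, \qquad h^{*} = \tanh(r),$$

where $\alpha$ holds the word-level weights and $h^{*}$ is the sentence representation passed to the output layer.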
3、Tricks
Insert special indicator tokens before and after each entity to mark its position, e.g. "<e1> people </e1> have been moving back into <e2> downtown </e2>".
Train with a regularized loss (L2 weight decay acts as the constraint).
三、Experimental results
The paper compares various model configurations on the SemEval-2010 Task 8 dataset; the full Att-BLSTM model reaches a macro-F1 of 84.0%, comparable to the state of the art at the time.
四、Conclusions
1、Key point
The model does not depend on any external NLP tools.
2、Novelty
It introduces the Attention-BiLSTM architecture for relation classification.
3、Takeaway
The network relies on no NLP tools or lexical resources to derive features; it only needs raw text with position indicators as input.
五、Code
1、Dataset
1.1 Raw dataset
train_file.txt (samples 1-8000)
1 "The system as described above has its greatest application in an arrayed <e1>configuration</e1> of antenna <e2>elements</e2>."
Component-Whole(e2,e1)
Comment: Not a collection: there is structure here, organisation.
2 "The <e1>child</e1> was carefully wrapped and bound into the <e2>cradle</e2> by means of a cord."
Other
Comment:
3 "The <e1>author</e1> of a keygen uses a <e2>disassembler</e2> to look at the raw assembly code."
Instrument-Agency(e2,e1)
Comment:
4 "A misty <e1>ridge</e1> uprises from the <e2>surge</e2>."
Other
Comment:
5 "The <e1>student</e1> <e2>association</e2> is the voice of the undergraduate student population of the State University of New York at Buffalo."
Member-Collection(e1,e2)
Comment:
6 "This is the sprawling <e1>complex</e1> that is Peru's largest <e2>producer</e2> of silver."
Other
Comment:
7 "The current view is that the chronic <e1>inflammation</e1> in the distal part of the stomach caused by Helicobacter pylori <e2>infection</e2> results in an increased acid production from the non-infected upper corpus region of the stomach."
Cause-Effect(e2,e1)
Comment:
8 "<e1>People</e1> have been moving back into <e2>downtown</e2>."
Entity-Destination(e1,e2)
Comment:
9 "The <e1>lawsonite</e1> was contained in a <e2>platinum crucible</e2> and the counter-weight was a plastic crucible with metal pieces."
Content-Container(e1,e2)
Comment: prototypical example
10 "The solute was placed inside a beaker and 5 mL of the <e1>solvent</e1> was pipetted into a 25 mL glass <e2>flask</e2> for each trial."
Entity-Destination(e1,e2)
Comment:
......
test_file.txt (samples 8001-10717)
8001 "The most common <e1>audits</e1> were about <e2>waste</e2> and recycling."
Message-Topic(e1,e2)
Comment: Assuming an audit = an audit document.
8002 "The <e1>company</e1> fabricates plastic <e2>chairs</e2>."
Product-Producer(e2,e1)
Comment: (a) is satisfied
8003 "The school <e1>master</e1> teaches the lesson with a <e2>stick</e2>."
Instrument-Agency(e2,e1)
Comment:
8004 "The suspect dumped the dead <e1>body</e1> into a local <e2>reservoir</e2>."
Entity-Destination(e1,e2)
Comment:
8005 "Avian <e1>influenza</e1> is an infectious disease of birds caused by type A strains of the influenza <e2>virus</e2>."
Cause-Effect(e2,e1)
Comment:
8006 "The <e1>ear</e1> of the African <e2>elephant</e2> is significantly larger--measuring 183 cm by 114 cm in the bush elephant."
Component-Whole(e1,e2)
Comment:
8007 "A child is told a <e1>lie</e1> for several years by their <e2>parents</e2> before he/she realizes that a Santa Claus does not exist."
Product-Producer(e1,e2)
Comment: (a) is satisfied; negation is outside
8008 "Skype, a free software, allows a <e1>hookup</e1> of multiple computer <e2>users</e2> to join in an online conference call without incurring any telephone costs."
Member-Collection(e2,e1)
Comment:
8009 "The disgusting scene was retaliation against her brother Philip who rents the <e1>room</e1> inside this apartment <e2>house</e2> on Lombard street."
Component-Whole(e1,e2)
Comment:
8010 "This <e1>thesis</e1> defines the <e2>clinical characteristics</e2> of amyloid disease."
Message-Topic(e1,e2)
Comment: may be we could leave clinical out of e2.
1.2 Preprocessed data
Each raw record spans four lines (numbered sentence, relation label, comment, blank line). preprocess.py below turns every record into one JSON object per line, tokenizing the sentence while keeping the <e1>/<e2> markers as standalone tokens.
preprocess.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import json
import re
from nltk.tokenize import word_tokenize
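# search_entity: pad the <e1>/<e2> tags with spaces so that NLTK's word_tokenize keeps them
# as standalone tokens, then restore the original tag spelling after tokenization.
# convert: parse one four-line record at a time and write it out as a JSON line.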
def search_entity(sentence):
e1 = re.findall(r'<e1>(.*)</e1>', sentence)[0]
e2 = re.findall(r'<e2>(.*)</e2>', sentence)[0]
sentence = sentence.replace('<e1>' + e1 + '</e1>', ' <e1> ' + e1 + ' </e1> ', 1)
sentence = sentence.replace('<e2>' + e2 + '</e2>', ' <e2> ' + e2 + ' </e2> ', 1)
sentence = word_tokenize(sentence)
sentence = ' '.join(sentence)
sentence = sentence.replace('< e1 >', '<e1>')
sentence = sentence.replace('< e2 >', '<e2>')
sentence = sentence.replace('< /e1 >', '</e1>')
sentence = sentence.replace('< /e2 >', '</e2>')
sentence = sentence.split()
assert '<e1>' in sentence
assert '<e2>' in sentence
assert '</e1>' in sentence
assert '</e2>' in sentence
return sentence
def convert(path_src, path_des):
with open(path_src, 'r', encoding='utf-8') as fr:
data = fr.readlines()
with open(path_des, 'w', encoding='utf-8') as fw:
for i in range(0, len(data), 4):
id_s, sentence = data[i].strip().split('\t')
sentence = sentence[1:-1]
sentence = search_entity(sentence)
meta = dict(
id=id_s,
relation=data[i+1].strip(),
sentence=sentence,
comment=data[i+2].strip()[8:]
)
json.dump(meta, fw, ensure_ascii=False)
fw.write('\n')
if __name__ == '__main__':
path_train = './SemEval2010_task8_all_data/SemEval2010_task8_training/TRAIN_FILE.TXT'
path_test = './SemEval2010_task8_all_data/SemEval2010_task8_testing_keys/TEST_FILE_FULL.TXT'
convert(path_train, 'train.json')
convert(path_test, 'test.json')
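Running preprocess.py writes train.json and test.json into the working directory. A quick sanity check of the first converted record (a minimal sketch, assuming the script has already been run):

import json

# print the first converted training record to verify the structure shown below
with open('train.json', encoding='utf-8') as fr:
    print(json.loads(fr.readline()))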
train.json
{"id": "1", "relation": "Component-Whole(e2,e1)", "sentence": ["The", "system", "as", "described", "above", "has", "its", "greatest", "application", "in", "an", "arrayed", "<e1>", "configuration", "</e1>", "of", "antenna", "<e2>", "elements", "</e2>", "."], "comment": " Not a collection: there is structure here, organisation."}
{"id": "2", "relation": "Other", "sentence": ["The", "<e1>", "child", "</e1>", "was", "carefully", "wrapped", "and", "bound", "into", "the", "<e2>", "cradle", "</e2>", "by", "means", "of", "a", "cord", "."], "comment": ""}
{"id": "3", "relation": "Instrument-Agency(e2,e1)", "sentence": ["The", "<e1>", "author", "</e1>", "of", "a", "keygen", "uses", "a", "<e2>", "disassembler", "</e2>", "to", "look", "at", "the", "raw", "assembly", "code", "."], "comment": ""}
{"id": "4", "relation": "Other", "sentence": ["A", "misty", "<e1>", "ridge", "</e1>", "uprises", "from", "the", "<e2>", "surge", "</e2>", "."], "comment": ""}
{"id": "5", "relation": "Member-Collection(e1,e2)", "sentence": ["The", "<e1>", "student", "</e1>", "<e2>", "association", "</e2>", "is", "the", "voice", "of", "the", "undergraduate", "student", "population", "of", "the", "State", "University", "of", "New", "York", "at", "Buffalo", "."], "comment": ""}
......
test.json
{"id": "8001", "relation": "Message-Topic(e1,e2)", "sentence": ["The", "most", "common", "<e1>", "audits", "</e1>", "were", "about", "<e2>", "waste", "</e2>", "and", "recycling", "."], "comment": " Assuming an audit = an audit document."}
{"id": "8002", "relation": "Product-Producer(e2,e1)", "sentence": ["The", "<e1>", "company", "</e1>", "fabricates", "plastic", "<e2>", "chairs", "</e2>", "."], "comment": " (a) is satisfied"}
{"id": "8003", "relation": "Instrument-Agency(e2,e1)", "sentence": ["The", "school", "<e1>", "master", "</e1>", "teaches", "the", "lesson", "with", "a", "<e2>", "stick", "</e2>", "."], "comment": ""}
{"id": "8004", "relation": "Entity-Destination(e1,e2)", "sentence": ["The", "suspect", "dumped", "the", "dead", "<e1>", "body", "</e1>", "into", "a", "local", "<e2>", "reservoir", "</e2>", "."], "comment": ""}
{"id": "8005", "relation": "Cause-Effect(e2,e1)", "sentence": ["Avian", "<e1>", "influenza", "</e1>", "is", "an", "infectious", "disease", "of", "birds", "caused", "by", "type", "A", "strains", "of", "the", "influenza", "<e2>", "virus", "</e2>", "."], "comment": ""}
......
1.3 relation2id
Other 0
Cause-Effect(e1,e2) 1
Cause-Effect(e2,e1) 2
Component-Whole(e1,e2) 3
Component-Whole(e2,e1) 4
Content-Container(e1,e2) 5
Content-Container(e2,e1) 6
Entity-Destination(e1,e2) 7
Entity-Destination(e2,e1) 8
Entity-Origin(e1,e2) 9
Entity-Origin(e2,e1) 10
Instrument-Agency(e1,e2) 11
Instrument-Agency(e2,e1) 12
Member-Collection(e1,e2) 13
Member-Collection(e2,e1) 14
Message-Topic(e1,e2) 15
Message-Topic(e2,e1) 16
Product-Producer(e1,e2) 17
Product-Producer(e2,e1) 18
2、Pre-trained word embeddings: static 50-dimensional HLBL vectors. (Note that config.py below defaults to ./embedding/glove.6B.100d.txt with word_dim=100; embedding_path and word_dim must be changed together to match whichever embedding file is actually used.)
hlbl-embeddings-scaled.EMBEDDING_SIZE=50
*UNKNOWN* -0.166038776479 0.104395984608 0.163119732357 0.0899594154863 -0.0192271099805 -0.0417631572501 -0.0163376687927 0.0357616216019 0.0536077591673 0.0127688536503 -0.00284508433021 -0.0626207031228 -0.0379452734015 -0.103548297666 0.0381169119981 0.00199421074321 -0.0474636488659 -0.0127526851513 0.016404178535 -0.12759853361 -0.0292937037717 -0.0512566352549 0.0233097445983 0.0360505083995 0.00229317984472 -0.0771565284227 0.0071461584378 -0.051608090196 -0.0267547654304 0.0492994451068 -0.0531630844999 0.00787191810391 0.082280106873 0.066908641868 -0.0283930612982 0.216840166248 0.164923151267 0.00188498983723 0.0328679039324 -0.00175432516758 0.0614261774935 0.0987773071377 0.0548423375506 -0.0307057922059 0.053074241476 0.04982054279 -0.0572485864016 0.132236444766 -0.0379717035014 -0.120915939814
the -0.0841015569168 0.145263825738 0.116945121935 -0.0754618634155 0.17901499611 -0.000652852605208 -0.0713783879233 0.207273704502 0.060711721477 0.0366727701165 -0.0269791566731 -0.156993473526 -0.0393947453024 0.00749161628231 -0.332851634057 -0.1708430781 -0.275163605231 -0.266592614101 0.43349041466 -0.00779248211778 0.031101796379 -0.0257114150838 0.174856713352 -0.0543054233622 -0.0846669459476 -0.006234398456 0.00414488584462 0.119738648443 -0.0914876936952 -0.317381121871 -0.27471439742 0.234269597998 0.170305945138 -0.0282815073325 -0.10127814458 0.156451476203 0.154703520781 -0.0014827085612 0.164287521114 0.0328582913203 0.0356570354049 -0.190254406793 -0.112029936115 -0.198875312619 0.00102875631152 -0.00161517169984 -0.125210890327 0.196903181061 -0.112017915766 -0.00838804375065
. -0.0875932389444 -0.0586365253633 0.0729727126603 0.32072000431 0.0745620569276 -0.0494709138174 0.208708067552 -0.025035364294 -0.197531050237 0.177318202028 0.297077745222 -0.0256369072571 0.182364658364 0.189089099105 0.0589179494006 -0.0627276310572 0.0682898379459 0.241161712515 0.253510796291 -0.0325139691451 -0.0129081882483 -0.083367340352 0.0276167362372 -0.00757124183183 -0.0905801885623 0.305015208385 0.0755474920504 -0.00516459185438 -0.0412876867803 0.105047372601 -0.718674456034 0.184682477295 0.232732814491 0.0929975692214 0.0999329447708 -0.0968008990987 0.421525505372 -0.136460066398 -0.323294448817 0.118318915141 0.415411774103 -0.135770867168 0.0404792691614 0.264279769529 -0.133076243622 0.195087919022 -0.087589323012 0.0335223022065 -0.0365650611956 -0.0163760300203
, -0.023019838485 0.277215570968 0.241932261453 -0.105403438907 0.247316949736 0.0859618436243 -0.0130132156599 0.123988163629 -0.150741462418 0.129993766762 0.0766431623839 0.0547135456598 0.187342182554 0.176303102861 -0.121401723217 0.0458278230666 0.0339804870854 -0.0619606057248 0.0514787739809 0.00732501266557 0.0879996990484 -0.369288823679 0.235222707122 -0.0528783055204 0.0121891472663 -0.165169815904 -0.136829953355 -0.0750751223049 -0.0503433833321 0.0782539868365 -0.400940778018 -0.099745222007 -0.152448498545 -0.0815002789835 -0.010575616616 0.331604536668 -0.0124179474775 0.00173559407939 -0.230971231526 0.0162523457081 0.213848645598 0.184698023693 0.158368229826 0.0975422545404 -0.0307127563081 0.093420146492 -0.0377856184872 -0.0181716170654 0.43322993915 -0.113289957059
to 0.134693667961 0.392203653086 0.0346151199225 0.135354475458 0.0719918082372 0.118667933013 -0.0698386234679 -0.0139927084407 0.144452931939 0.0383223273458 -0.0491954394553 -0.126435975874 0.23979196724 -0.186550477314 0.0602616605691 -0.0875395769807 0.0788848675161 0.132691898026 0.155618778336 0.00680378469567 -0.126513561203 -0.436124771467 0.132675129426 -0.0946286638801 0.0986847070674 -0.354397304845 -0.196909463175 -0.0911408611189 0.134975690877 0.0625931974859 0.0108112360985 -0.107933544401 -0.166545488854 0.0137397678012 -0.0268394211932 -0.260328038765 0.0745185746772 0.020864049205 0.133485534344 -0.0479098207297 0.145382061477 -0.116284346216 0.0822848147919 -0.00621959258902 0.0135679910959 -0.0723116375013 -0.422793539068 0.144456402991 -0.119019192402 0.0659297394103
......
3、config.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import argparse
import torch
import os
import random
import json
import numpy as np
class Config(object):
def __init__(self):
# get init config
args = self.__get_config()
for key in args.__dict__:
setattr(self, key, args.__dict__[key])
# select device
self.device = None
if self.cuda >= 0 and torch.cuda.is_available():
self.device = torch.device('cuda:{}'.format(self.cuda))
else:
self.device = torch.device('cpu')
# determine the model name and model dir
if self.model_name is None:
self.model_name = 'Att_BLSTM'
self.model_dir = os.path.join(self.output_dir, self.model_name)
if not os.path.exists(self.model_dir):
os.makedirs(self.model_dir)
# backup data
self.__config_backup(args)
# set the random seed
self.__set_seed(self.seed)
def __get_config(self):
parser = argparse.ArgumentParser()
parser.description = 'config for models'
# several key selective parameters
parser.add_argument('--data_dir', type=str,
default='./data',
help='dir to load data')
parser.add_argument('--output_dir', type=str,
default='./output',
help='dir to save output')
# word embedding
parser.add_argument('--embedding_path', type=str,
default='./embedding/glove.6B.100d.txt',
help='pre_trained word embedding')
parser.add_argument('--word_dim', type=int,
default=100,
help='dimension of word embedding')
# train settings
parser.add_argument('--model_name', type=str,
default=None,
help='model name')
parser.add_argument('--mode', type=int,
default=1,
choices=[0, 1],
help='running mode: 1 for training; otherwise testing')
parser.add_argument('--seed', type=int,
default=5782,
help='random seed')
parser.add_argument('--cuda', type=int,
default=0,
help='num of gpu device, if -1, select cpu')
parser.add_argument('--epoch', type=int,
default=30,
                            help='max epochs during training')
# hyper parameters
parser.add_argument('--batch_size', type=int,
default=10,
help='batch size')
parser.add_argument('--lr', type=float,
default=1.0,
help='learning rate')
parser.add_argument('--max_len', type=int,
default=100,
help='max length of sentence')
parser.add_argument('--emb_dropout', type=float,
default=0.3,
                            help='the probability of dropout in the embedding layer')
parser.add_argument('--lstm_dropout', type=float,
default=0.3,
                            help='the probability of dropout in the (Bi)LSTM layer')
parser.add_argument('--linear_dropout', type=float,
default=0.5,
                            help='the probability of dropout in the linear layer')
parser.add_argument('--hidden_size', type=int,
default=100,
help='the dimension of hidden units in (Bi)LSTM layer')
parser.add_argument('--layers_num', type=int,
default=1,
help='num of RNN layers')
parser.add_argument('--L2_decay', type=float, default=1e-5,
help='L2 weight decay')
args = parser.parse_args()
return args
def __set_seed(self, seed=1234):
os.environ['PYTHONHASHSEED'] = '{}'.format(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed) # set seed for cpu
torch.cuda.manual_seed(seed) # set seed for current gpu
torch.cuda.manual_seed_all(seed) # set seed for all gpu
def __config_backup(self, args):
config_backup_path = os.path.join(self.model_dir, 'config.json')
with open(config_backup_path, 'w', encoding='utf-8') as fw:
json.dump(vars(args), fw, ensure_ascii=False)
def print_config(self):
for key in self.__dict__:
print(key, end=' = ')
print(self.__dict__[key])
if __name__ == '__main__':
config = Config()
config.print_config()
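All hyper-parameters are plain argparse flags, so they can be overridden on the command line, e.g. when switching to the 50-dimensional HLBL vectors from section 2. The sketch below does the same thing programmatically; the flag names come from __get_config above, while the embedding path is an assumption about the local file layout:

import sys

# inject CLI-style overrides before Config() parses sys.argv
# (equivalent to passing the same flags on the command line)
sys.argv += ['--embedding_path', './embedding/hlbl-embeddings-scaled.EMBEDDING_SIZE=50',
             '--word_dim', '50']
from config import Config

config = Config()
config.print_config()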
4、model.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
class Att_BLSTM(nn.Module):
def __init__(self, word_vec, class_num, config):
super().__init__()
self.word_vec = word_vec
self.class_num = class_num
# hyper parameters and others
self.max_len = config.max_len
self.word_dim = config.word_dim
self.hidden_size = config.hidden_size
self.layers_num = config.layers_num
self.emb_dropout_value = config.emb_dropout
self.lstm_dropout_value = config.lstm_dropout
self.linear_dropout_value = config.linear_dropout
# net structures and operations
self.word_embedding = nn.Embedding.from_pretrained(
embeddings=self.word_vec,
freeze=False,
)
self.lstm = nn.LSTM(
input_size=self.word_dim,
hidden_size=self.hidden_size,
num_layers=self.layers_num,
bias=True,
batch_first=True,
dropout=0,
bidirectional=True,
)
self.tanh = nn.Tanh()
self.emb_dropout = nn.Dropout(self.emb_dropout_value)
self.lstm_dropout = nn.Dropout(self.lstm_dropout_value)
self.linear_dropout = nn.Dropout(self.linear_dropout_value)
self.att_weight = nn.Parameter(torch.randn(1, self.hidden_size, 1))
self.dense = nn.Linear(
in_features=self.hidden_size,
out_features=self.class_num,
bias=True
)
# initialize weight
init.xavier_normal_(self.dense.weight)
init.constant_(self.dense.bias, 0.)
def lstm_layer(self, x, mask):
lengths = torch.sum(mask.gt(0), dim=-1)
x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
h, (_, _) = self.lstm(x)
h, _ = pad_packed_sequence(h, batch_first=True, padding_value=0.0, total_length=self.max_len)
h = h.view(-1, self.max_len, 2, self.hidden_size)
h = torch.sum(h, dim=2) # B*L*H
return h
def attention_layer(self, h, mask):
att_weight = self.att_weight.expand(mask.shape[0], -1, -1) # B*H*1
att_score = torch.bmm(self.tanh(h), att_weight) # B*L*H * B*H*1 -> B*L*1
# mask, remove the effect of 'PAD'
mask = mask.unsqueeze(dim=-1) # B*L*1
att_score = att_score.masked_fill(mask.eq(0), float('-inf')) # B*L*1
att_weight = F.softmax(att_score, dim=1) # B*L*1
reps = torch.bmm(h.transpose(1, 2), att_weight).squeeze(dim=-1) # B*H*L * B*L*1 -> B*H*1 -> B*H
reps = self.tanh(reps) # B*H
return reps
def forward(self, data):
token = data[:, 0, :].view(-1, self.max_len)
mask = data[:, 1, :].view(-1, self.max_len)
emb = self.word_embedding(token) # B*L*word_dim
emb = self.emb_dropout(emb)
h = self.lstm_layer(emb, mask) # B*L*H
h = self.lstm_dropout(h)
reps = self.attention_layer(h, mask) # B*reps
reps = self.linear_dropout(reps)
logits = self.dense(reps)
return logits
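A quick smoke test of the model above (a sketch: the dimensions and the SimpleNamespace stand-in for Config are illustrative). The input layout expected by forward() is (batch, 2, max_len), where channel 0 holds the token ids and channel 1 the padding mask:

import torch
from types import SimpleNamespace
from model import Att_BLSTM

cfg = SimpleNamespace(max_len=100, word_dim=100, hidden_size=100, layers_num=1,
                      emb_dropout=0.3, lstm_dropout=0.3, linear_dropout=0.5)
word_vec = torch.randn(5000, cfg.word_dim)           # stand-in for the pre-trained embedding matrix
model = Att_BLSTM(word_vec=word_vec, class_num=19, config=cfg)

data = torch.zeros(2, 2, cfg.max_len, dtype=torch.long)
data[:, 0, :12] = torch.randint(6, 5000, (2, 12))    # fake token ids for 12 real tokens
data[:, 1, :12] = 1                                   # mask = 1 on real tokens, 0 on padding
print(model(data).shape)                              # torch.Size([2, 19])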
5、train_or_test.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import os
import torch
import torch.nn as nn
import torch.optim as optim
from config import Config
from utils import WordEmbeddingLoader, RelationLoader, SemEvalDataLoader
from model import Att_BLSTM
from evaluate import Eval
def print_result(predict_label, id2rel, start_idx=8001):
with open('predicted_result.txt', 'w', encoding='utf-8') as fw:
for i in range(0, predict_label.shape[0]):
fw.write('{}\t{}\n'.format(start_idx+i, id2rel[int(predict_label[i])]))
def train(model, criterion, loader, config):
train_loader, dev_loader, _ = loader
optimizer = optim.Adadelta(model.parameters(), lr=config.lr, weight_decay=config.L2_decay)
print(model)
    print('training model parameters:')
for name, param in model.named_parameters():
if param.requires_grad:
print('%s : %s' % (name, str(param.data.shape)))
print('--------------------------------------')
print('start to train the model ...')
eval_tool = Eval(config)
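    # despite its name, min_f1 tracks the best (maximum) dev F1 seen so far;
    # the checkpoint is overwritten whenever the dev F1 improves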
min_f1 = -float('inf')
for epoch in range(1, config.epoch+1):
for step, (data, label) in enumerate(train_loader):
model.train()
data = data.to(config.device)
label = label.to(config.device)
optimizer.zero_grad()
logits = model(data)
loss = criterion(logits, label)
loss.backward()
nn.utils.clip_grad_value_(model.parameters(), clip_value=5)
optimizer.step()
_, train_loss, _ = eval_tool.evaluate(model, criterion, train_loader)
f1, dev_loss, _ = eval_tool.evaluate(model, criterion, dev_loader)
print('[%03d] train_loss: %.3f | dev_loss: %.3f | micro f1 on dev: %.4f'
% (epoch, train_loss, dev_loss, f1), end=' ')
if f1 > min_f1:
min_f1 = f1
torch.save(model.state_dict(), os.path.join(config.model_dir, 'model.pkl'))
print('>>> save models!')
else:
print()
def test(model, criterion, loader, config):
print('--------------------------------------')
print('start test ...')
_, _, test_loader = loader
model.load_state_dict(torch.load(os.path.join(config.model_dir, 'model.pkl')))
eval_tool = Eval(config)
f1, test_loss, predict_label = eval_tool.evaluate(model, criterion, test_loader)
print('test_loss: %.3f | micro f1 on test: %.4f' % (test_loss, f1))
return predict_label
if __name__ == '__main__':
config = Config()
print('--------------------------------------')
print('some config:')
config.print_config()
print('--------------------------------------')
print('start to load data ...')
word2id, word_vec = WordEmbeddingLoader(config).load_embedding()
rel2id, id2rel, class_num = RelationLoader(config).get_relation()
loader = SemEvalDataLoader(rel2id, word2id, config)
train_loader, dev_loader = None, None
if config.mode == 1: # train mode
train_loader = loader.get_train()
dev_loader = loader.get_dev()
test_loader = loader.get_test()
loader = [train_loader, dev_loader, test_loader]
print('finish!')
print('--------------------------------------')
model = Att_BLSTM(word_vec=word_vec, class_num=class_num, config=config)
model = model.to(config.device)
criterion = nn.CrossEntropyLoss()
if config.mode == 1: # train mode
train(model, criterion, loader, config)
predict_label = test(model, criterion, loader, config)
print_result(predict_label, id2rel)
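Note that SemEvalDataLoader.get_dev() simply reuses test.json (see section 7), so the dev F1 printed after each epoch is already computed on the test split. To evaluate a saved checkpoint without retraining, mode 0 can be used; a minimal sketch (assumes a completed training run has written output/Att_BLSTM/model.pkl):

import subprocess
import sys

# evaluation-only run: --mode 0 skips training, loads the saved model.pkl and scores test.json
subprocess.run([sys.executable, 'train_or_test.py', '--mode', '0'], check=True)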
6、evaluate.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import numpy as np
import torch
def semeval_scorer(predict_label, true_label, class_num=10):
import math
assert true_label.shape[0] == predict_label.shape[0]
confusion_matrix = np.zeros(shape=[class_num, class_num], dtype=np.float32)
xDIRx = np.zeros(shape=[class_num], dtype=np.float32)
for i in range(true_label.shape[0]):
true_idx = math.ceil(true_label[i]/2)
predict_idx = math.ceil(predict_label[i]/2)
if true_label[i] == predict_label[i]:
confusion_matrix[predict_idx][true_idx] += 1
else:
if true_idx == predict_idx:
xDIRx[predict_idx] += 1
else:
confusion_matrix[predict_idx][true_idx] += 1
col_sum = np.sum(confusion_matrix, axis=0).reshape(-1)
row_sum = np.sum(confusion_matrix, axis=1).reshape(-1)
f1 = np.zeros(shape=[class_num], dtype=np.float32)
for i in range(0, class_num): # ignore the 'Other'
try:
p = float(confusion_matrix[i][i]) / float(col_sum[i] + xDIRx[i])
r = float(confusion_matrix[i][i]) / float(row_sum[i] + xDIRx[i])
f1[i] = (2 * p * r / (p + r))
except:
pass
actual_class = 0
total_f1 = 0.0
for i in range(1, class_num):
if f1[i] > 0.0: # classes that not in the predict label are not considered
actual_class += 1
total_f1 += f1[i]
try:
macro_f1 = total_f1 / actual_class
except:
macro_f1 = 0.0
return macro_f1
class Eval(object):
def __init__(self, config):
self.device = config.device
def evaluate(self, model, criterion, data_loader):
predict_label = []
true_label = []
total_loss = 0.0
with torch.no_grad():
model.eval()
for _, (data, label) in enumerate(data_loader):
data = data.to(self.device)
label = label.to(self.device)
logits = model(data)
loss = criterion(logits, label)
total_loss += loss.item() * logits.shape[0]
_, pred = torch.max(logits, dim=1) # replace softmax with max function, same impacts
pred = pred.cpu().detach().numpy().reshape((-1, 1))
label = label.cpu().detach().numpy().reshape((-1, 1))
predict_label.append(pred)
true_label.append(label)
predict_label = np.concatenate(predict_label, axis=0).reshape(-1).astype(np.int64)
true_label = np.concatenate(true_label, axis=0).reshape(-1).astype(np.int64)
eval_loss = total_loss / predict_label.shape[0]
f1 = semeval_scorer(predict_label, true_label)
return f1, eval_loss, predict_label
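semeval_scorer folds the two directions of each relation into one of ten bins via ceil(id/2) (with the relation2id table above, ids 1 and 2 both map to Cause-Effect, and so on). A prediction with the right relation but the wrong direction is counted in xDIRx, i.e. against both precision and recall of that bin, and the final score is macro-F1 averaged over the nine non-Other relations; the value printed as "micro f1" in train_or_test.py is in fact this macro-F1. A toy check (a sketch; label ids follow relation2id):

import numpy as np

from evaluate import semeval_scorer

pred = np.array([1, 2, 3, 0])   # Cause-Effect(e1,e2), Cause-Effect(e2,e1), Component-Whole(e1,e2), Other
true = np.array([1, 1, 4, 0])   # exact match, wrong direction, wrong direction, exact match
print(semeval_scorer(pred, true))   # 0.5 -- only Cause-Effect (P = R = 0.5) contributes; zero-F1 classes are skipped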
7、utils.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import os
import json
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
class WordEmbeddingLoader(object):
"""
A loader for pre-trained word embedding
"""
def __init__(self, config):
self.path_word = config.embedding_path # path of pre-trained word embedding
self.word_dim = config.word_dim # dimension of word embedding
def load_embedding(self):
word2id = dict() # word to wordID
word_vec = list() # wordID to word embedding
word2id['PAD'] = len(word2id) # PAD character
word2id['UNK'] = len(word2id) # out of vocabulary
word2id['<e1>'] = len(word2id)
word2id['<e2>'] = len(word2id)
word2id['</e1>'] = len(word2id)
word2id['</e2>'] = len(word2id)
with open(self.path_word, 'r', encoding='utf-8') as fr:
for line in fr:
line = line.strip().split()
if len(line) != self.word_dim + 1:
continue
word2id[line[0]] = len(word2id)
word_vec.append(np.asarray(line[1:], dtype=np.float32))
word_vec = np.stack(word_vec)
vec_mean, vec_std = word_vec.mean(), word_vec.std()
special_emb = np.random.normal(vec_mean, vec_std, (6, self.word_dim))
special_emb[0] = 0 # <pad> is initialize as zero
word_vec = np.concatenate((special_emb, word_vec), axis=0)
word_vec = word_vec.astype(np.float32).reshape(-1, self.word_dim)
word_vec = torch.from_numpy(word_vec)
return word2id, word_vec
class RelationLoader(object):
def __init__(self, config):
self.data_dir = config.data_dir
def __load_relation(self):
relation_file = os.path.join(self.data_dir, 'relation2id.txt')
rel2id = {}
id2rel = {}
with open(relation_file, 'r', encoding='utf-8') as fr:
for line in fr:
relation, id_s = line.strip().split()
id_d = int(id_s)
rel2id[relation] = id_d
id2rel[id_d] = relation
return rel2id, id2rel, len(rel2id)
def get_relation(self):
return self.__load_relation()
class SemEvalDateset(Dataset):
def __init__(self, filename, rel2id, word2id, config):
self.filename = filename
self.rel2id = rel2id
self.word2id = word2id
self.max_len = config.max_len
self.data_dir = config.data_dir
self.dataset, self.label = self.__load_data()
def __symbolize_sentence(self, sentence):
"""
Args:
sentence (list)
"""
mask = [1] * len(sentence)
words = []
length = min(self.max_len, len(sentence))
mask = mask[:length]
for i in range(length):
words.append(self.word2id.get(sentence[i].lower(), self.word2id['UNK']))
if length < self.max_len:
for i in range(length, self.max_len):
mask.append(0) # 'PAD' mask is zero
words.append(self.word2id['PAD'])
unit = np.asarray([words, mask], dtype=np.int64)
unit = np.reshape(unit, newshape=(1, 2, self.max_len))
return unit
def __load_data(self):
path_data_file = os.path.join(self.data_dir, self.filename)
data = []
labels = []
with open(path_data_file, 'r', encoding='utf-8') as fr:
for line in fr:
line = json.loads(line.strip())
label = line['relation']
sentence = line['sentence']
label_idx = self.rel2id[label]
one_sentence = self.__symbolize_sentence(sentence)
data.append(one_sentence)
labels.append(label_idx)
return data, labels
def __getitem__(self, index):
data = self.dataset[index]
label = self.label[index]
return data, label
def __len__(self):
return len(self.label)
class SemEvalDataLoader(object):
def __init__(self, rel2id, word2id, config):
self.rel2id = rel2id
self.word2id = word2id
self.config = config
def __collate_fn(self, batch):
data, label = zip(*batch) # unzip the batch data
data = list(data)
label = list(label)
data = torch.from_numpy(np.concatenate(data, axis=0))
label = torch.from_numpy(np.asarray(label, dtype=np.int64))
return data, label
def __get_data(self, filename, shuffle=False):
dataset = SemEvalDateset(filename, self.rel2id, self.word2id, self.config)
loader = DataLoader(
dataset=dataset,
batch_size=self.config.batch_size,
shuffle=shuffle,
num_workers=2,
collate_fn=self.__collate_fn
)
return loader
def get_train(self):
return self.__get_data('train.json', shuffle=True)
def get_dev(self):
return self.__get_data('test.json', shuffle=False)
def get_test(self):
return self.__get_data('test.json', shuffle=False)
if __name__ == '__main__':
from config import Config
config = Config()
word2id, word_vec = WordEmbeddingLoader(config).load_embedding()
rel2id, id2rel, class_num = RelationLoader(config).get_relation()
loader = SemEvalDataLoader(rel2id, word2id, config)
test_loader = loader.get_train()
for step, (data, label) in enumerate(test_loader):
print(type(data), data.shape)
print(type(label), label.shape)
break