READING AND ANSWERING GIVEN REASONING PATHS
Reader model
Multi-task reader model
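- Reading comprehension: BERT extracts the answer span from a reasoning path.
- Reasoning-path re-ranking: BERT's output at the [CLS] position is used to estimate the probability that a reasoning path contains the answer. The reasoning paths are then re-ranked by this probability.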
$$P(E|q) = \sigma(w_n \cdot u_E) \quad \text{s.t.} \quad u_E = \mathrm{BERT}_{[CLS]}(q, E) \in \mathbb{R}^D$$
$w_n \in \mathbb{R}^D$: weight vector
$P(E|q)$: probability that reasoning path $E$ contains the answer to question $q$
$$E_{best} = \underset{E \in \mathbf{E}}{\arg\max}\ P(E|q)$$
$E_{best}$: the best reasoning path, where $\mathbf{E}$ is the set of candidate reasoning paths
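The two formulas above can be sketched in plain Python. This is a toy illustration, not the repository's code. `u_E` is assumed to be an already-computed [CLS] vector; here it is a short hand-written list standing in for BERT's output. `path_prob` and `best_path` are hypothetical helper names.

```python
import math

def path_prob(w_n, u_E):
    """P(E|q) = sigmoid(w_n . u_E), with u_E standing in for BERT's [CLS] vector."""
    z = sum(w * u for w, u in zip(w_n, u_E))
    return 1.0 / (1.0 + math.exp(-z))

def best_path(paths, w_n, cls_vectors):
    """E_best = argmax over candidate paths of P(E|q)."""
    probs = [path_prob(w_n, u) for u in cls_vectors]
    k = max(range(len(paths)), key=lambda idx: probs[idx])
    return paths[k], probs[k]

# Toy example: three candidate paths with 2-dimensional "[CLS]" vectors (D = 2).
paths = ["path_1", "path_2", "path_3"]
cls_vectors = [[0.0, 0.0], [2.0, 0.0], [0.0, 1.0]]
w_n = [1.0, -1.0]
print(best_path(paths, w_n, cls_vectors))  # ('path_2', ~0.881)
```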
$$S_{read} = \underset{i, j,\ i \le j}{\arg\max}\ P^{start}_i P^{end}_j$$
$S_{read}$: the predicted answer span
$P^{start}_i$, $P^{end}_j$: the probabilities that the $i$-th token and the $j$-th token of $E_{best}$ are, respectively, the start and end positions of the answer
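A minimal sketch of the span selection $S_{read}$ follows. It assumes `p_start` and `p_end` are per-token probability lists for $E_{best}$. All probabilities are non-negative, so for a fixed end $j$ the best start is simply the largest $P^{start}_i$ with $i \le j$. A running maximum therefore finds the argmax in one pass instead of checking every pair. The formula has no span-length limit, and neither does this sketch. `best_span` is a hypothetical name.

```python
def best_span(p_start, p_end):
    """Return ((i, j), score) maximizing p_start[i] * p_end[j] subject to i <= j."""
    best_i = 0                  # index of the largest p_start seen so far (i <= j)
    best, best_score = (0, 0), -1.0
    for j in range(len(p_end)):
        if p_start[j] > p_start[best_i]:
            best_i = j
        score = p_start[best_i] * p_end[j]
        if score > best_score:
            best, best_score = (best_i, j), score
    return best, best_score

span, score = best_span([0.1, 0.6, 0.2, 0.1], [0.1, 0.1, 0.7, 0.1])
print(span)  # (1, 2)
```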
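- Adding negative examples:
To train the reader to distinguish relevant from irrelevant reasoning paths, the original training data is augmented with additional negative examples that simulate incomplete evidence.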
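The exact augmentation procedure is not spelled out here. The following is therefore only a hypothetical sketch of one way to build such a negative example in the SQuAD v2.0 format used for the datasets in this post. One gold paragraph of a reasoning path is replaced by a distractor paragraph, and the question is marked unanswerable (`is_impossible: True`, no answers). The function name, the `_neg` id suffix, and the choice of which paragraph to replace are all assumptions.

```python
def make_negative_example(qid, question, gold_paragraphs, distractor):
    """Hypothetical: turn a gold reasoning path into an unanswerable example.

    The last gold paragraph is swapped for a distractor, so the evidence is
    incomplete and the question should no longer be answerable from it.
    """
    incomplete_path = gold_paragraphs[:-1] + [distractor]
    return {
        "context": " ".join(incomplete_path),
        "qas": [{
            "id": qid + "_neg",
            "question": question,
            "answers": [],          # no answer span
            "is_impossible": True,  # SQuAD v2.0 flag for unanswerable questions
        }],
    }

neg = make_negative_example(
    "q1", "Which hedgehog did this singer also voice?",
    ["Paragraph A.", "Paragraph B."], "Unrelated paragraph.")
print(neg["context"])  # Paragraph A. Unrelated paragraph.
```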
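- Loss function
The objective is the sum of the cross-entropy losses of the span prediction and re-ranking tasks. The loss for a question $q$ and its candidate evidence $E$ is: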
$$L_{read} = L_{span} + L_{no\_answer} = \left(-\log P^{start}_{y^{start}} - \log P^{end}_{y^{end}}\right) - \log P^r$$
$y^{start}$, $y^{end}$: the ground-truth start and end positions.
$L_{no\_answer}$: the loss of the re-ranking model. It trains the model to recognize distorted reasoning paths that contain no answer.
$P^r$: $P^r = P(E|q)$ if $E$ is the ground-truth evidence; otherwise $P^r = 1 - P(E|q)$.
The span loss is masked for negative examples to avoid unintended effects on span prediction.
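For a single question/evidence pair, the loss above can be written out directly. This sketch works on plain probability lists rather than the repository's batched PyTorch code. It assumes that "negative example" means the same as "not ground-truth evidence", so the span term is dropped exactly when `is_gold` is false. `read_loss` is a hypothetical name.

```python
import math

def read_loss(p_start, p_end, y_start, y_end, p_evidence, is_gold):
    """L_read = L_span + L_no_answer for one (q, E) pair.

    p_start, p_end: per-token start/end probabilities over E
    p_evidence:     P(E|q) from the re-ranking head
    is_gold:        True if E is the ground-truth evidence
    """
    # P^r = P(E|q) for ground-truth evidence, 1 - P(E|q) otherwise
    p_r = p_evidence if is_gold else 1.0 - p_evidence
    l_no_answer = -math.log(p_r)
    if not is_gold:
        return l_no_answer  # span loss is masked for negative examples
    l_span = -math.log(p_start[y_start]) - math.log(p_end[y_end])
    return l_span + l_no_answer

pos = read_loss([0.5, 0.5], [0.25, 0.75], 0, 1, p_evidence=0.8, is_gold=True)
neg = read_loss([0.5, 0.5], [0.25, 0.75], 0, 1, p_evidence=0.8, is_gold=False)
print(round(pos, 4), round(neg, 4))  # 1.204 1.6094
```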
Code implementation
Optimizer:
optimizer = BertAdam(optimizer_grouped_parameters,
                     lr=args.learning_rate,
                     warmup=args.warmup_proportion,
                     t_total=num_train_optimization_steps)
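`warmup` is the fraction of the `t_total` optimization steps during which the learning rate ramps up. Below is a minimal sketch of a linear warmup followed by linear decay to zero. It assumes the behavior of the linear warmup schedule in later `pytorch_pretrained_bert` releases; the exact schedule differs between versions, so check the installed one. `lr_multiplier` is a hypothetical name.

```python
def lr_multiplier(step, t_total, warmup):
    """Factor applied to the base learning rate at a given optimization step."""
    progress = step / t_total
    if progress < warmup:
        return progress / warmup                        # linear ramp-up
    return max((progress - 1.0) / (warmup - 1.0), 0.0)  # linear decay to 0

base_lr = 5e-5
for step in (0, 5, 10, 55, 100):
    print(step, base_lr * lr_multiplier(step, t_total=100, warmup=0.1))
```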
Load the training data with a DataLoader:
train_sampler = DistributedSampler(train_data)
train_dataloader = DataLoader(
    train_data, sampler=train_sampler, batch_size=args.train_batch_size)
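`DistributedSampler` gives each process its own slice of the dataset. The sketch below imitates that partitioning with `shuffle=False` in plain Python; by default the real sampler also shuffles with an epoch-dependent seed. It then applies the fixed-size batching that `DataLoader` performs. It assumes the dataset has at least as many examples as there are processes. The function names are hypothetical.

```python
import math

def distributed_indices(dataset_len, num_replicas, rank):
    """Indices seen by one process, like DistributedSampler with shuffle=False."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Repeat leading indices so every process gets the same number of samples.
    indices += indices[:total_size - len(indices)]
    return indices[rank:total_size:num_replicas]

def make_batches(indices, batch_size):
    """Split indices into batches; the last batch may be smaller."""
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]

for rank in range(4):
    print(rank, make_batches(distributed_indices(10, num_replicas=4, rank=rank), 2))
```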
Show progress with tqdm:
for _ in trange(int(args.num_train_epochs), desc="Epoch"):
    ...  # one epoch of training
Sample output:
Epoch: 100%|██████████| 3/3 [00:03<00:00, 1.00s/it]
Create the model:
model = BertForQuestionAnsweringConfidence.from_pretrained(
    args.bert_model,
    cache_dir=os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)),
    num_labels=4,
    no_masking=args.no_masking,
    lambda_scale=args.lambda_scale)
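Loss computation:
loss = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask,
             start_positions=start_positions, end_positions=end_positions, switch_list=switches)
Inside the model, the loss is $L_{span} + L_{no\_answer}$. The span losses are multiplied by span_mask, which is 1.0 for examples with switch_list == 0 and 0.0 otherwise, so negative examples contribute no span loss: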
span_mask = (switch_list == 0).type(torch.FloatTensor).cuda()
start_losses = loss_fct(start_logits, start_positions) * span_mask
end_losses = loss_fct(end_logits, end_positions) * span_mask
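Standard-library helpers used in the code:
collections.OrderedDict()  # dict that remembers insertion order (it does not sort)
collections.namedtuple()   # tuple subclass with named fields, used like a lightweight struct
json.dumps()               # serialize a Python object to a JSON string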
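Here is a quick demonstration of the three helpers. The `NbestPrediction` fields resemble the per-candidate records kept by SQuAD-style prediction code, but the name and values are only illustrative.

```python
import collections
import json

# OrderedDict keeps keys in insertion order; it does not sort them.
od = collections.OrderedDict()
od["zeta"] = 1
od["alpha"] = 2
print(list(od))  # ['zeta', 'alpha']

# namedtuple: a tuple whose fields can also be read by name.
NbestPrediction = collections.namedtuple(
    "NbestPrediction", ["text", "start_logit", "end_logit"])
pred = NbestPrediction(text="jim cummings", start_logit=3.2, end_logit=2.9)
print(pred.text, pred[0] == pred.text)  # jim cummings True

# json.dumps: Python object -> JSON string
print(json.dumps({"q1": pred.text}))  # {"q1": "jim cummings"}
```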
Keep only questions whose answer can be found in the document:
actual_text = " ".join(
    doc_tokens[start_position:(end_position + 1)])
cleaned_answer_text = " ".join(
    whitespace_tokenize(orig_answer_text))
if actual_text.find(cleaned_answer_text) == -1:
    logger.warning("Could not find answer: '%s' vs. '%s'",
                   actual_text, cleaned_answer_text)
    continue
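The same check can be written as a self-contained function. `whitespace_tokenize` below is a stand-in written for this sketch: it strips the text and splits on whitespace, which is what the helper of that name in BERT's tokenization code does. `answer_is_recoverable` is a hypothetical wrapper.

```python
def whitespace_tokenize(text):
    """Strip and split on whitespace."""
    text = text.strip()
    return text.split() if text else []

def answer_is_recoverable(doc_tokens, start_position, end_position, orig_answer_text):
    """True if the labeled token span actually contains the answer text."""
    actual_text = " ".join(doc_tokens[start_position:end_position + 1])
    cleaned_answer_text = " ".join(whitespace_tokenize(orig_answer_text))
    return actual_text.find(cleaned_answer_text) != -1

doc_tokens = ["sung", "by", "jim", "cummings", "as", "pooh"]
print(answer_is_recoverable(doc_tokens, 2, 3, "  jim   cummings "))  # True
print(answer_is_recoverable(doc_tokens, 2, 2, "jim cummings"))       # False
```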
Filter out overly long answers:
if len(orig_answer_text.split()) > max_answer_len:
    logger.info("Omitting a long answer: '%s'", orig_answer_text)
    continue
num_train_optimization_steps // int(save_chunk)  # floor division, so the result is an integer
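The answer span is set to the first occurrence of a string that matches the answer.
Keep the n_best_size highest-scoring indexes (index_and_score is sorted by score, highest first):
for i in range(len(index_and_score)):
    if i >= n_best_size:
        break
    best_indexes.append(index_and_score[i][0])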
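Both steps can be sketched in plain Python. `first_occurrence_span` matches at the whitespace-token level, which simplifies matching the answer string in the raw text. `get_best_indexes` is similar in spirit to the `_get_best_indexes` helper in SQuAD-style BERT scripts and ranks candidates by score from highest to lowest. Both names are illustrative.

```python
def first_occurrence_span(doc_tokens, answer):
    """(start, end) token indices of the first exact match of answer, else None."""
    ans = answer.split()
    if not ans:
        return None
    n = len(ans)
    for i in range(len(doc_tokens) - n + 1):
        if doc_tokens[i:i + n] == ans:
            return i, i + n - 1
    return None

def get_best_indexes(logits, n_best_size):
    """Indices of the n_best_size largest logits, highest first."""
    index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
    return [index for index, _ in index_and_score[:n_best_size]]

doc = "the lion king and the lion".split()
print(first_occurrence_span(doc, "the lion"))      # (0, 1)
print(get_best_indexes([0.1, 2.0, -1.0, 1.5], 2))  # [1, 3]
```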
BERT input example:
tokens: [CLS] this singer of a rather blu ##ster ##y day also voiced what hedge ##hog ? [SEP] " a rather blu ##ster ##y day " is a w ##him ##sic ##al song from the walt disney musical film feature ##tte , " winnie the po ##oh and the blu ##ster ##y day " . it was written by robert & richard sherman and sung by jim cummings as " po ##oh " . james jonah cummings ( born november 3 , 1952 ) is an american voice actor and singer , who has appeared in almost 400 roles . he is known for vo ##icing the title character from " dark ##wing duck " , dr . robot ##nik from " sonic the hedge ##hog " , and pete . his other characters include winnie the po ##oh , ti ##gger , and the tasmanian devil . he has performed in numerous disney and dream ##works animation ##s including " ala ##ddin " , " the lion king " , " bal ##to " , " ant ##z " , " the road to el dora ##do " , " sh ##rek " , and " the princess and the frog " . he has also provided voice - over work for video games , such as " ice ##wind dale " , " fallout " , " " , " bald ##ur ' s gate " , " mass effect 2 " , " " , " " , " " , and " sp ##lat ##ter ##house " . [SEP]
Pad inputs shorter than max_seq_length:
while len(input_ids) < max_seq_length:
    input_ids.append(pad_token)
    input_mask.append(0 if mask_padding_with_zero else 1)
    segment_ids.append(pad_token_segment_id)
    p_mask.append(1)
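The padding loop can be wrapped as a function so it runs on a toy example. The defaults `pad_token=0` and `pad_token_segment_id=0` are assumptions that match BERT's usual `[PAD]` id and segment id. In `p_mask`, 1 marks positions that cannot be part of the answer. Unlike the original loop, this version copies its inputs instead of appending in place. `pad_to_length` is a hypothetical name.

```python
def pad_to_length(input_ids, input_mask, segment_ids, p_mask, max_seq_length,
                  pad_token=0, pad_token_segment_id=0, mask_padding_with_zero=True):
    """Pad all four feature lists up to max_seq_length (returns new lists)."""
    input_ids, input_mask = list(input_ids), list(input_mask)
    segment_ids, p_mask = list(segment_ids), list(p_mask)
    while len(input_ids) < max_seq_length:
        input_ids.append(pad_token)
        input_mask.append(0 if mask_padding_with_zero else 1)  # 0 = ignore padding
        segment_ids.append(pad_token_segment_id)
        p_mask.append(1)  # padding can never be part of the answer
    return input_ids, input_mask, segment_ids, p_mask

ids, mask, segs, pm = pad_to_length([101, 2054, 102], [1, 1, 1], [0, 0, 0], [1, 0, 1], 5)
print(ids)   # [101, 2054, 102, 0, 0]
print(mask)  # [1, 1, 1, 0, 0]
```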
Example input_mask (1 = real token, 0 = padding):
input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
Experiments
Download the code
!git clone https://github.com/AkariAsai/learning_to_retrieve_reasoning_paths.git
%cd /content/learning_to_retrieve_reasoning_paths
!pip install -r requirements.txt
Download the training data
%cd /content/learning_to_retrieve_reasoning_paths
!mkdir data
%cd data
!mkdir hotpot
%cd hotpot
!gdown https://drive.google.com/uc?id=1_a8KliAHKIwrYRrHgHOlzM0Jon3AqZLs
!mv hotpot_reader_train_data.json.json____ hotpot_reader_train_data.json
!gdown https://drive.google.com/uc?id=1R4exuPDaV2yD18xUBsnNyQpXwn0ty5pc
!mv nq_reader_train_data_public.json.json____ nq_reader_train_data_public.json
!gdown https://drive.google.com/uc?id=1FB5gB9aM8rmbpIwYf-1o6lmxMYQsg_rP
!mv squad_reader_train_data.json.json____ squad_reader_train_data.json
!gdown https://drive.google.com/uc?id=1MysthH2TRYoJcK_eLOueoLeYR42T-JhB
!ls
Train the model
All datasets are in SQuAD v2.0 format.
Train on hotpot_dev_squad_v2.0_format.json
hotpot_reader_train_data.json is too large to train on here, so hotpot_dev_squad_v2.0_format.json is used only to walk through the training process.
%cd /content/learning_to_retrieve_reasoning_paths/reader/
!python run_reader_confidence.py \
--bert_model bert-base-uncased \
--output_dir output_hotpot_bert_base \
--train_file /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json \
--predict_file /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json \
--max_seq_length 384 \
--do_train \
--do_predict \
--do_lower_case \
--version_2_with_negative \
--train_batch_size 16
--train_batch_size: adjust to the available GPU memory
--version_2_with_negative: train with negative (unanswerable) examples
Training only
%cd /content/learning_to_retrieve_reasoning_paths/reader/
!python run_reader_confidence.py \
--bert_model bert-base-uncased \
--output_dir output_hotpot_bert_base \
--train_file /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json \
--max_seq_length 384 \
--do_train \
--do_lower_case \
--version_2_with_negative \
--train_batch_size 16
Prediction only
%cd /content/learning_to_retrieve_reasoning_paths/reader/
!python run_reader_confidence.py \
--bert_model bert-base-uncased \
--output_dir output_hotpot_bert_base \
--predict_file /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json \
--max_seq_length 384 \
--do_predict \
--do_lower_case \
--version_2_with_negative \
--train_batch_size 16
Evaluation
Download the evaluation dataset
%cd /content/learning_to_retrieve_reasoning_paths
!mkdir data
%cd data
!mkdir hotpot
%cd hotpot
!gdown https://drive.google.com/uc?id=1MysthH2TRYoJcK_eLOueoLeYR42T-JhB
!ls
Evaluate the model
predictions.json is generated automatically by the prediction step.
%cd /content/learning_to_retrieve_reasoning_paths/reader/
!wget https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
!mv index.html evaluate-v2.0.py
!python evaluate-v2.0.py \
/content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json \
/content/learning_to_retrieve_reasoning_paths/reader/output_hotpot_bert_base/predictions.json
Evaluation data: /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json
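`evaluate-v2.0.py` reports exact match (EM) and F1. The sketch below reproduces the usual SQuAD-style per-answer scoring: lowercase, remove punctuation and the articles a/an/the, collapse whitespace, then compute token-overlap F1. An empty string stands for "no answer". Treat it as an illustration and rely on the official script for reported numbers.

```python
import collections
import re
import string

PUNCT = set(string.punctuation)

def normalize_answer(s):
    """Lowercase, drop punctuation and articles, collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in PUNCT)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def exact_match(gold, pred):
    return int(normalize_answer(gold) == normalize_answer(pred))

def f1_score(gold, pred):
    gold_toks = normalize_answer(gold).split()
    pred_toks = normalize_answer(pred).split()
    if not gold_toks or not pred_toks:
        # No-answer case: full credit only if both are empty.
        return float(gold_toks == pred_toks)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    return 2 * precision * recall / (precision + recall)

print(exact_match("Jim Cummings", "jim cummings."))  # 1
print(f1_score("James Jonah Cummings", "Jim Cummings"))
```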