[Code Implementation] READING AND ANSWERING GIVEN REASONING PATHS

Reader model

The multi-task reader model handles two tasks:

  • Reading comprehension

    Use BERT to extract the answer span from a reasoning path.

  • Re-ranking reasoning paths

    Use BERT's output at the [CLS] token to estimate the probability that a reasoning path contains the answer, then re-rank the paths by this probability.

$$P(E \mid q) = \sigma(w_n \cdot u_E) \quad \text{s.t.} \quad u_E = \mathrm{BERT}_{[CLS]}(q, E) \in \mathbb{R}^D$$

$w_n \in \mathbb{R}^D$: weight vector

$P(E \mid q)$: the probability of reasoning path $E$

$$E_{best} = \underset{E \in \mathcal{E}}{\arg\max}\ P(E \mid q)$$

$E_{best}$: the best reasoning path
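A minimal sketch of this re-ranking step, assuming a `bert` encoder that returns the last hidden states first (as in current HuggingFace `transformers`) and pre-encoded (question, path) pairs; `rerank_paths` is a hypothetical helper, not the repository's code:

import torch

def rerank_paths(bert, w_n, encoded_pairs):
    # Picks E_best = argmax_E P(E|q), with P(E|q) = sigma(w_n . u_E).
    probs = []
    for input_ids, token_type_ids in encoded_pairs:
        # u_E: the [CLS] (first-token) vector of the last hidden layer
        last_hidden = bert(input_ids, token_type_ids=token_type_ids)[0]
        u_E = last_hidden[0, 0]                 # shape (D,)
        probs.append(torch.sigmoid(w_n @ u_E))  # P(E|q)
    probs = torch.stack(probs)
    return int(probs.argmax()), probs           # index of E_best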

$$S_{read} = \underset{i, j,\ i \le j}{\arg\max}\ P^{start}_i P^{end}_j$$

$S_{read}$: the answer span

$P^{start}_i, P^{end}_j$: the probabilities that the $i$-th and $j$-th tokens of $E_{best}$ are the start and end positions of the answer, respectively.
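The constrained argmax over $i \le j$ can be computed exhaustively; a sketch in PyTorch (the `max_answer_len` bound is an assumption, mirroring the answer-length filter used during preprocessing):

import torch

def best_span(p_start, p_end, max_answer_len=30):
    # score[i, j] = P_i^start * P_j^end for every token pair
    n = p_start.size(0)
    score = p_start.unsqueeze(1) * p_end.unsqueeze(0)
    # keep only spans with i <= j and bounded length
    valid = torch.triu(torch.ones(n, n, dtype=torch.bool))
    valid &= ~torch.triu(torch.ones(n, n, dtype=torch.bool),
                         diagonal=max_answer_len)
    score = score.masked_fill(~valid, 0.0)
    flat = int(score.argmax())   # flattened (i, j) index
    return flat // n, flat % n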

  • Adding negative examples

    To train the reader model to distinguish relevant from irrelevant reasoning paths, the original training data is augmented with additional negative examples that simulate incomplete evidence.

  • Loss function

The objective is the sum of the cross-entropy losses of the span-prediction and re-ranking tasks. The loss for a question $q$ and its candidate evidence $E$ is:

$$L_{read} = L_{span} + L_{no\_answer} = \left(-\log P^{start}_{y^{start}} - \log P^{end}_{y^{end}}\right) - \log P^r$$

$y^{start}, y^{end}$: the ground-truth start and end positions.

$L_{no\_answer}$: the loss of the re-ranking model, which discriminates distractor reasoning paths that contain no answer.

$P^r$: $P^r = P(E \mid q)$ if $E$ is the ground-truth evidence; otherwise $P^r = 1 - P(E \mid q)$.

The span losses of negative examples are masked to avoid unintended effects on span prediction.
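A sketch of this loss following the paper's formulation (note that the repository's BertForQuestionAnsweringConfidence actually uses a 4-way switch classifier, num_labels=4, rather than a plain sigmoid re-ranker, so this is only an illustration):

import torch
import torch.nn.functional as F

def reader_loss(start_logits, end_logits, rank_logit,
                y_start, y_end, is_positive):
    # L_span: cross-entropy over start/end positions, per example
    l_span = (F.cross_entropy(start_logits, y_start, reduction='none')
              + F.cross_entropy(end_logits, y_end, reduction='none'))
    l_span = l_span * is_positive.float()  # mask negatives' span loss
    # L_no_answer: -log P^r, with P^r = P(E|q) for positives
    # and P^r = 1 - P(E|q) otherwise
    p = torch.sigmoid(rank_logit)
    p_r = torch.where(is_positive, p, 1.0 - p)
    return (l_span - torch.log(p_r + 1e-12)).mean()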

Code implementation

Optimizer:

optimizer = BertAdam(optimizer_grouped_parameters,
                     lr=args.learning_rate,
                     warmup=args.warmup_proportion,
                     t_total=num_train_optimization_steps)
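For context, num_train_optimization_steps is typically derived from the dataset size, batch size, gradient-accumulation steps, and epoch count; a sketch with variable names assumed from the usual BERT fine-tuning scripts:

# total optimizer steps = batches per epoch / grad accumulation * epochs
num_train_optimization_steps = int(
    len(train_examples) / args.train_batch_size /
    args.gradient_accumulation_steps) * int(args.num_train_epochs)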

Load the training data with a DataLoader

train_sampler = DistributedSampler(train_data)
train_dataloader = DataLoader(
    train_data, sampler=train_sampler, batch_size=args.train_batch_size)

Show training progress

for _ in trange(int(args.num_train_epochs), desc="Epoch"):

Epoch: 100%|██████████| 3/3 [00:03<00:00, 1.00s/it]

Create the model

model = BertForQuestionAnsweringConfidence.from_pretrained(
    args.bert_model,
    cache_dir=os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE),
        'distributed_{}'.format(args.local_rank)),
    num_labels=4,
    no_masking=args.no_masking,
    lambda_scale=args.lambda_scale)

Loss function

loss = model(input_ids=input_ids, token_type_ids=segment_ids,
             attention_mask=input_mask,
             start_positions=start_positions,
             end_positions=end_positions,
             switch_list=switches)

This computes $L_{span} + L_{no\_answer}$.
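A typical optimization step around this call looks roughly as follows (a sketch; gradient accumulation and fp16 handling are omitted):

loss.backward()        # backpropagate L_span + L_no_answer
optimizer.step()       # BertAdam update with warmup schedule
optimizer.zero_grad()  # clear gradients for the next batch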

Inside the model, the span losses of negative (no-answer) paths are masked out:

span_mask = (switch_list == 0).type(torch.FloatTensor).cuda()
start_losses = loss_fct(
    start_logits, start_positions) * span_mask
end_losses = loss_fct(end_logits, end_positions) * span_mask

Utility functions used in the implementation:

collections.OrderedDict()   # dict that remembers insertion order
collections.namedtuple()    # struct-like tuple with named fields
json.dumps()                # encode a Python object as a JSON string

Use only questions whose answers can actually be found in the document

actual_text = " ".join(
    doc_tokens[start_position:(end_position + 1)])
cleaned_answer_text = " ".join(
    whitespace_tokenize(orig_answer_text))
if actual_text.find(cleaned_answer_text) == -1:
    logger.warning("Could not find answer: '%s' vs. '%s'",
                   actual_text, cleaned_answer_text)
    continue

Filter out overly long answers

if len(orig_answer_text.split()) > max_answer_len:
    logger.info(
        "Omitting a long answer: '%s'", orig_answer_text)
    continue

num_train_optimization_steps // int(save_chunk)  # floor (integer) division

Set the answer span to the first-occurring string that matches the annotated answer.

for i in range(len(index_and_score)):
    if i >= n_best_size:
        break
    best_indexes.append(index_and_score[i][0])
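Here index_and_score is presumably the per-token logits sorted in descending order; a self-contained version of this n-best selection, mirroring the standard BERT SQuAD utility (helper name assumed):

def get_best_indexes(logits, n_best_size):
    # sort token positions by logit, highest score first
    index_and_score = sorted(
        enumerate(logits), key=lambda x: x[1], reverse=True)
    best_indexes = []
    for i in range(len(index_and_score)):
        if i >= n_best_size:
            break
        best_indexes.append(index_and_score[i][0])
    return best_indexes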

Example BERT input

tokens: [CLS] this singer of a rather blu ##ster ##y day also voiced what hedge ##hog ? [SEP] " a rather blu ##ster ##y day " is a w ##him ##sic ##al song from the walt disney musical film feature ##tte , " winnie the po ##oh and the blu ##ster ##y day " . it was written by robert & richard sherman and sung by jim cummings as " po ##oh " . james jonah cummings ( born november 3 , 1952 ) is an american voice actor and singer , who has appeared in almost 400 roles . he is known for vo ##icing the title character from " dark ##wing duck " , dr . robot ##nik from " sonic the hedge ##hog " , and pete . his other characters include winnie the po ##oh , ti ##gger , and the tasmanian devil . he has performed in numerous disney and dream ##works animation ##s including " ala ##ddin " , " the lion king " , " bal ##to " , " ant ##z " , " the road to el dora ##do " , " sh ##rek " , and " the princess and the frog " . he has also provided voice - over work for video games , such as " ice ##wind dale " , " fallout " , " " , " bald ##ur ' s gate " , " mass effect 2 " , " " , " " , " " , and " sp ##lat ##ter ##house " . [SEP]

Pad inputs shorter than the maximum sequence length

while len(input_ids) < max_seq_length:
    input_ids.append(pad_token)
    input_mask.append(0 if mask_padding_with_zero else 1)
    segment_ids.append(pad_token_segment_id)
    p_mask.append(1)
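After padding, every feature list should be exactly max_seq_length long, which the standard BERT preprocessing verifies with assertions:

assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length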

Example

input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

Experiments

Download the code

!git clone https://github.com/AkariAsai/learning_to_retrieve_reasoning_paths.git
%cd /content/learning_to_retrieve_reasoning_paths
!pip install -r requirements.txt

Download the training data

%cd /content/learning_to_retrieve_reasoning_paths
!mkdir data
%cd data
!mkdir hotpot
%cd hotpot
!gdown https://drive.google.com/uc?id=1_a8KliAHKIwrYRrHgHOlzM0Jon3AqZLs
!mv hotpot_reader_train_data.json.json____ hotpot_reader_train_data.json
!gdown https://drive.google.com/uc?id=1R4exuPDaV2yD18xUBsnNyQpXwn0ty5pc
!mv nq_reader_train_data_public.json.json____ nq_reader_train_data_public.json
!gdown https://drive.google.com/uc?id=1FB5gB9aM8rmbpIwYf-1o6lmxMYQsg_rP
!mv squad_reader_train_data.json.json____ squad_reader_train_data.json
!gdown https://drive.google.com/uc?id=1MysthH2TRYoJcK_eLOueoLeYR42T-JhB
!ls

Train the model

All datasets use the SQuAD v2.0 format.

Train with hotpot_dev_squad_v2.0_format.json

hotpot_reader_train_data.json is too large to train on here, so hotpot_dev_squad_v2.0_format.json is used just to walk through the training process.

%cd /content/learning_to_retrieve_reasoning_paths/reader/
!python run_reader_confidence.py \
  --bert_model bert-base-uncased \
  --output_dir output_hotpot_bert_base \
  --train_file /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json \
  --predict_file /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json \
  --max_seq_length 384 \
  --do_train \
  --do_predict \
  --do_lower_case \
  --version_2_with_negative \
  --train_batch_size 16

--train_batch_size: adjust according to available GPU memory

--version_2_with_negative: train with negative examples

Training only

%cd /content/learning_to_retrieve_reasoning_paths/reader/
!python run_reader_confidence.py \
  --bert_model bert-base-uncased \
  --output_dir output_hotpot_bert_base \
  --train_file /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json \
  --max_seq_length 384 \
  --do_train \
  --do_lower_case \
  --version_2_with_negative \
  --train_batch_size 16

Prediction only

%cd /content/learning_to_retrieve_reasoning_paths/reader/
!python run_reader_confidence.py \
  --bert_model bert-base-uncased \
  --output_dir output_hotpot_bert_base \
  --predict_file /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json \
  --max_seq_length 384 \
  --do_predict \
  --do_lower_case \
  --version_2_with_negative \
  --train_batch_size 16

Evaluation

Download the evaluation dataset

%cd /content/learning_to_retrieve_reasoning_paths
!mkdir data
%cd data
!mkdir hotpot
%cd hotpot
!gdown https://drive.google.com/uc?id=1MysthH2TRYoJcK_eLOueoLeYR42T-JhB
!ls

Evaluate the model

predictions.json is generated automatically by the prediction step.

%cd /content/learning_to_retrieve_reasoning_paths/reader/
!wget https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
!mv index.html evaluate-v2.0.py
!python evaluate-v2.0.py \
  /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json \
  /content/learning_to_retrieve_reasoning_paths/reader/output_hotpot_bert_base/predictions.json

Evaluation data: /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json
