问答模型(七)——整合代码

from paddlenlp.datasets import load_dataset
import paddlenlp as ppnlp
from utils import prepare_train_features, prepare_validation_features
from functools import partial
from paddlenlp.metrics.squad import squad_evaluate, compute_prediction
import paddlenlp.transformers
import collections
import time
import json
from paddlenlp.data import Stack, Dict, Pad
from paddlenlp.datasets import load_dataset
import paddlenlp as ppnlp
from utils import prepare_train_features, prepare_validation_features
from functools import partial
from paddlenlp.metrics.squad import squad_evaluate, compute_prediction
import paddle
from paddlenlp.data import Stack, Dict, Pad
from paddlenlp.datasets import MapDataset
import collections
import time
import json
import pandas as pd

@paddle.no_grad()
def do_predict(model, data_loader):
    """Run extractive-QA inference over every batch in *data_loader*.

    Collects per-example start/end logits, decodes them into answer strings
    with ``compute_prediction``, prints the first two (question, context,
    answer) triples, and returns the last printed example's
    ``(query, text, answer)``.
    """
    model.eval()

    all_start_logits = []
    all_end_logits = []
    tic = time.time()

    for batch in data_loader:
        input_ids, token_type_ids = batch
        start_tensor, end_tensor = model(input_ids, token_type_ids)

        # Convert the whole batch to numpy once, then collect row by row.
        start_rows = start_tensor.numpy()
        end_rows = end_tensor.numpy()
        for start_row, end_row in zip(start_rows, end_rows):
            # Progress report every 1000 collected examples.
            if all_start_logits and len(all_start_logits) % 1000 == 0:
                print("Processing example: %d" % len(all_start_logits))
                print('time per 1000:', time.time() - tic)
                tic = time.time()
            all_start_logits.append(start_row)
            all_end_logits.append(end_row)

    # Decode raw logits into an {example_id: answer_text} mapping.
    all_predictions, _, _ = compute_prediction(
        data_loader.dataset.data, data_loader.dataset.new_data,
        (all_start_logits, all_end_logits), False, 20, 30)

    # Show at most the first two examples; `example` keeps the last shown one.
    for shown, example in enumerate(data_loader.dataset.data, start=1):
        print()
        print('问题:', example['question'])
        print('原文:', ''.join(example['context']))
        print('答案:', all_predictions[example['id']])
        if shown >= 2:
            break

    model.train()
    query = example['question']
    text = ''.join(example['context'])
    answer = all_predictions[example['id']]

    return query, text, answer


def read(data_path):
    """Yield ``(key, example)`` pairs from a SQuAD-style JSON file.

    Each example dict carries "id", "title", "context", "question",
    "answers" (list of answer texts) and "answer_starts" (list of
    character offsets). ``key`` is a running 0-based counter.

    Fixes vs. original: the default for a missing "answers" field is now
    an empty list instead of the empty string ``''`` (which only worked
    by accident of string iteration), and the answers list is fetched
    once per question instead of twice.
    """
    key = 0
    # json.load reads the whole file eagerly, so the handle can be
    # closed before iteration begins instead of staying open for the
    # generator's entire lifetime.
    with open(data_path, encoding="utf-8") as f:
        durobust = json.load(f)
    for article in durobust["data"]:
        title = article.get("title", "")
        for paragraph in article["paragraphs"]:
            context = paragraph[
                "context"]  # do not strip leading blank spaces GH-2585
            for qa in paragraph["qas"]:
                raw_answers = qa.get("answers", [])
                answer_starts = [a["answer_start"] for a in raw_answers]
                answers = [a["text"] for a in raw_answers]
                # Features currently used are "context", "question", and "answers".
                # Others are extracted here for the ease of future expansions.
                yield key, {
                    "id": qa["id"],
                    "title": title,
                    "context": context,
                    "question": qa["question"],
                    "answers": answers,
                    "answer_starts": answer_starts,
                }
                key += 1


def load_json(path, tokenizer):
    """Build an evaluation DataLoader from a SQuAD-style JSON file.

    Loads raw examples via the `read` generator, tokenizes them with
    `prepare_validation_features`, and wraps the result in a padded,
    non-shuffled paddle DataLoader (batch size 8).

    Args:
        path: path to the SQuAD-format JSON file.
        tokenizer: a PaddleNLP tokenizer providing pad token ids.

    Returns:
        paddle.io.DataLoader yielding (input_ids, token_type_ids) batches.
    """
    # Load the raw (key, example) pairs produced by `read`.
    dev_ds = ppnlp.datasets.load_dataset(read, data_path = path, lazy=False)
    # Strip the running key and keep only the example dicts, then rewrap
    # them in a MapDataset so .map() can be applied below.
    diction = []
    for idx in range(len(dev_ds)):
        a, dic = dev_ds[idx]
        dic = dict(dic)
        diction.append(dic)

    dev_ds = diction
    dev_ds = MapDataset(dev_ds)

    # Tokenization limits: contexts longer than max_seq_length are split
    # into overlapping windows that slide by doc_stride tokens.
    max_seq_length = 512
    doc_stride = 128



    # Tokenize/featurize every example in place (batched for speed).
    dev_trans_func = partial(prepare_validation_features, 
                           max_seq_length=max_seq_length, 
                           doc_stride=doc_stride,
                           tokenizer=tokenizer)
                           
    dev_ds.map(dev_trans_func, batched=True)

    batch_size = 8


    # Evaluation order must be stable, so no shuffling.
    dev_batch_sampler = paddle.io.BatchSampler(
        dev_ds, batch_size=batch_size, shuffle=False)

    # Collate function: pad input_ids / token_type_ids within each batch
    # to equal length using the tokenizer's pad ids.
    dev_batchify_fn = lambda samples, fn=Dict({
        "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
        "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id)
    }): fn(samples)

    dev_data_loader = paddle.io.DataLoader(
        dataset=dev_ds,
        batch_sampler=dev_batch_sampler,
        collate_fn=dev_batchify_fn,
        return_list=True)


    return dev_data_loader

将以上方法全部放入 pred_utils.py 中:

# test.py — minimal driver: load the fine-tuned checkpoint, build the
# evaluation data loader, and run prediction through pred_utils.
import paddlenlp as ppnlp
import paddlenlp.transformers
import pred_utils


# Load model weights and tokenizer from the fine-tuned checkpoint directory.
# NOTE(review): assumes 'checkpoint_raw8_2' exists locally — confirm path.
model = ppnlp.transformers.BertForQuestionAnswering.from_pretrained('checkpoint_raw8_2')
tokenizer = ppnlp.transformers.BertTokenizer.from_pretrained('checkpoint_raw8_2')

# Build the padded, non-shuffled DataLoader from the raw test file.
dev_data_loader = pred_utils.load_json('test.json', tokenizer)

# Run inference; returns the last printed example's question/context/answer.
query, text, answer = pred_utils.do_predict(model, dev_data_loader)

在 test.py 中编写调用示例,使代码更简洁,也便于同学整合后进行多线程等操作。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值