Q&A Model Training (Part 1)

Reference: AI Studio -- BERT-based machine reading comprehension

Install the paddle 2.0.2 framework
  • Run the following command to install (the Baidu mirror is recommended):
python -m pip install paddlepaddle==2.0.2 -i https://mirror.baidu.com/pypi/simple
Modify the downloaded paddle source code

In theory, the downloaded library is the same one used in the reference link, so everything should work right after the environment is set up. In practice, the pip-installed library differs from the files in paddle's GitHub repository, so the files in use have to be updated against GitHub (why does the campus network block GitHub o(╥﹏╥)o)

  • Add the load_dataset function to dataset.py (a usage sketch follows the code)
def load_dataset(path_or_read_func,
                 name=None,
                 data_files=None,
                 splits=None,
                 lazy=None,
                 **kwargs):
    # Relies on inspect, SimpleBuilder, import_main_class and load_from_hf
    # already being available in dataset.py.
    if inspect.isfunction(path_or_read_func):
        assert lazy is not None, "lazy can not be None in custom mode."
        kwargs['name'] = name
        kwargs['data_files'] = data_files
        kwargs['splits'] = splits
        custom_kwargs = {}
        for name in inspect.signature(path_or_read_func).parameters.keys():
            if name in kwargs.keys():
                custom_kwargs[name] = kwargs[name]

        reader_instance = SimpleBuilder(lazy=lazy, read_func=path_or_read_func)
        return reader_instance.read(**custom_kwargs)
    else:
        try:
            reader_cls = import_main_class(path_or_read_func)
        except ModuleNotFoundError:
            datasets = load_from_hf(
                path_or_read_func, name=name, splits=splits, **kwargs)
        else:
            reader_instance = reader_cls(lazy=lazy, name=name, **kwargs)

            # Check if the selected name and split are valid for this DatasetBuilder
            if hasattr(reader_instance, 'BUILDER_CONFIGS'):
                if name in reader_cls.BUILDER_CONFIGS.keys():
                    split_names = reader_cls.BUILDER_CONFIGS[name][
                        'splits'].keys()
                else:
                    raise ValueError(
                        'Invalid name "{}". Should be one of {}.'.format(
                            name, list(reader_cls.BUILDER_CONFIGS.keys())))
            elif hasattr(reader_instance, 'SPLITS'):
                split_names = reader_instance.SPLITS.keys()
            else:
                raise AttributeError(
                    "Either 'SPLITS' or 'BUILDER_CONFIGS' must be implemented for DatasetBuilder."
                )

            selected_splits = []
            if isinstance(splits, (list, tuple)):
                selected_splits.extend(splits)
            else:
                selected_splits += [splits]

            for split_name in selected_splits:
                if split_name not in split_names and split_name is not None:
                    raise ValueError('Invalid split "{}". Should be one of {}.'.
                                     format(split_name, list(split_names)))

            datasets = reader_instance.read_datasets(
                data_files=data_files, splits=splits)
        return datasets
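With the patched dataset.py in place, load_dataset supports both branches above: a registered dataset name resolved through import_main_class, or a custom reader function (in which case lazy must be set explicitly). A minimal sketch of both call styles; the read_custom function and my_data.tsv are illustrative assumptions, not part of the original post:

from paddlenlp.datasets import load_dataset

# Built-in dataset: resolved via its DatasetBuilder and the requested splits.
train_ds, dev_ds = load_dataset('dureader_robust', splits=('train', 'dev'))

# Custom dataset: keyword arguments matching the reader's signature
# (here data_path) are filtered out of kwargs and passed through.
def read_custom(data_path):
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            question, context = line.rstrip('\n').split('\t')
            yield {'question': question, 'context': context}

custom_ds = load_dataset(read_custom, data_path='my_data.tsv', lazy=False)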
  • Add the Dict class to collate.py (a usage sketch follows the code)
class Dict(object):
    def __init__(self, fn):
        assert isinstance(fn, dict), (
            'Input pattern not understood. The input of Dict must be a dict '
            'with key of input column name and value of collate_fn. '
            'Received fn=%s' % str(fn))

        self._fn = fn

        for col_name, ele_fn in self._fn.items():
            assert callable(ele_fn), (
                'Batchify functions must be callable! type(fn[%s]) = %s' % (
                    col_name, str(type(ele_fn))))

    def __call__(self, data):
        ret = []
        for col_name, ele_fn in self._fn.items():
            # Gather one column across the batch and apply its collate_fn.
            result = ele_fn([ele[col_name] for ele in data])
            if isinstance(result, (tuple, list)):
                ret.extend(result)
            else:
                ret.append(result)
        return tuple(ret)
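In practice, Dict is combined with the other batchify helpers to turn a list of example dicts into model-ready arrays. A minimal sketch, assuming the usual paddlenlp.data Pad and Stack helpers; the field names and toy samples are illustrative:

from paddlenlp.data import Dict, Pad, Stack

# Two toy tokenized samples of different lengths.
samples = [
    {'input_ids': [1, 5, 9], 'token_type_ids': [0, 0, 0], 'label': 1},
    {'input_ids': [1, 7], 'token_type_ids': [0, 0], 'label': 0},
]

batchify_fn = Dict({
    'input_ids': Pad(axis=0, pad_val=0),       # pad to the longest sequence
    'token_type_ids': Pad(axis=0, pad_val=0),
    'label': Stack(dtype='int64'),             # stack scalar labels into an array
})

# Each collate_fn's output becomes one element of the returned tuple.
input_ids, token_type_ids, labels = batchify_fn(samples)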
  • Modify chnsenticorp.py to get rid of TSVDataset (a smoke test follows the code)
import collections
import json
import os

from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url
from paddlenlp.utils.env import DATA_HOME
from . import DatasetBuilder

__all__ = ['ChnSentiCorp']


class ChnSentiCorp(DatasetBuilder):
    """
    ChnSentiCorp (by Tan Songbo at ICT of Chinese Academy of Sciences, and for
    opinion mining)

    """

    URL = "https://bj.bcebos.com/paddlenlp/datasets/ChnSentiCorp.zip"
    MD5 = "7ef61b08ad10fbddf2ba97613f071561"
    META_INFO = collections.namedtuple('META_INFO', ('file', 'md5'))
    SPLITS = {
        'train': META_INFO(
            os.path.join('ChnSentiCorp', 'ChnSentiCorp', 'train.tsv'),
            '689360c4a4a9ce8d8719ed500ae80907'),
        'dev': META_INFO(
            os.path.join('ChnSentiCorp', 'ChnSentiCorp', 'dev.tsv'),
            '20c77cc2371634731a367996b097ec0a'),
        'test': META_INFO(
            os.path.join('ChnSentiCorp', 'ChnSentiCorp', 'test.tsv'),
            '9b4dc7d1e4ada48c645b7e938592f49c'),
    }

    def _get_data(self, mode, **kwargs):
        """Downloads dataset."""
        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
        filename, data_hash = self.SPLITS[mode]
        fullname = os.path.join(default_root, filename)
        if not os.path.exists(fullname) or (data_hash and
                                            not md5file(fullname) == data_hash):
            get_path_from_url(self.URL, default_root, self.MD5)

        return fullname

    def _read(self, filename, split):
        """Reads data."""
        with open(filename, 'r', encoding='utf-8') as f:
            head = None
            for line in f:
                data = line.strip().split("\t")
                if not head:
                    head = data
                else:
                    if split == 'train':
                        label, text = data
                        yield {"text": text, "label": label, "qid": ''}
                    elif split == 'dev':
                        qid, label, text = data
                        yield {"text": text, "label": label, "qid": qid}
                    elif split == 'test':
                        qid, text = data
                        yield {"text": text, "label": '', "qid": qid}

    def get_labels(self):
        """
        Return labels of the ChnSentiCorp object.
        """
        return ["0", "1"]

This is driving me crazy. Every package in the underlying call chain has problems; might as well just reinstall.

So infuriating 😤

  • Bugs keep popping up one after another, all inside the library files. Time to pull off a grand substitution act and replace paddle's files wholesale!

It turns out paddle also uses the jieba tokenizer for word segmentation:

import jieba

class JiebaTokenizer(BaseTokenizer):
    def __init__(self, vocab):
        super(JiebaTokenizer, self).__init__(vocab)
        self.tokenizer = jieba.Tokenizer()
        # Seed jieba's internal frequency dict with the vocab so that
        # segmentation only ever emits in-vocabulary tokens, and mark the
        # tokenizer initialized to skip loading jieba's default dictionary.
        self.tokenizer.FREQ = {key: 1 for key in self.vocab.token_to_idx.keys()}
        self.tokenizer.total = len(self.tokenizer.FREQ)
        self.tokenizer.initialized = True
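For completeness, a hedged sketch of how this tokenizer is driven; the tiny Vocab.from_dict vocabulary here is an illustrative assumption, a real one would come from the model's vocab file:

from paddlenlp.data import JiebaTokenizer, Vocab

# Toy vocabulary; only in-vocab words can come out of the seeded FREQ dict.
vocab = Vocab.from_dict(
    {'[PAD]': 0, '[UNK]': 1, '百度': 2, '飞桨': 3}, unk_token='[UNK]')
tokenizer = JiebaTokenizer(vocab)
print(tokenizer.cut('百度飞桨'))  # expected: ['百度', '飞桨']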