参考 AI Studio——基于 BERT 模型的机器阅读理解
安装paddle2.0.2框架
- 执行以下命令安装(推荐使用百度源):
python -m pip install paddlepaddle==2.0.2 -i https://mirror.baidu.com/pypi/simple
修改下载到的paddle的源码
按理说下载好的库与参考链接中使用的库相同,在安装好该环境后就能直接使用。但实际情况是,pip安装的库与paddle的github仓库中的文件有所出入,因此将使用到的文件对照github进行更新(校园网为什么要屏蔽github o(╥﹏╥)o)
- 在dataset.py中添加load_dataset方法
def load_dataset(path_or_read_func,
                 name=None,
                 data_files=None,
                 splits=None,
                 lazy=None,
                 **kwargs):
    """Load a dataset from a custom reader function or a built-in builder.

    Args:
        path_or_read_func: Either a callable that yields examples (custom
            mode), or the name/path of a dataset module containing a
            ``DatasetBuilder`` subclass.
        name: Optional dataset configuration name; when the builder defines
            ``BUILDER_CONFIGS`` it must be one of its keys.
        data_files: Optional explicit data file path(s) forwarded to the
            builder.
        splits: A single split name, or a list/tuple of split names.
        lazy: Whether examples are produced lazily. Must not be None in
            custom mode.
        **kwargs: Extra keyword arguments forwarded to the reader.

    Returns:
        The dataset(s) produced by the selected reader.

    Raises:
        ValueError: If ``name`` or a requested split is invalid for the
            selected builder.
        AttributeError: If the builder defines neither ``SPLITS`` nor
            ``BUILDER_CONFIGS``.
    """
    if inspect.isfunction(path_or_read_func):
        # Custom mode: the caller supplied the reader function directly.
        assert lazy is not None, "lazy can not be None in custom mode."
        kwargs['name'] = name
        kwargs['data_files'] = data_files
        kwargs['splits'] = splits
        # Forward only the kwargs the custom reader actually declares.
        # (Loop variable renamed from `name` — the original shadowed the
        # `name` parameter.)
        custom_kwargs = {}
        for param in inspect.signature(path_or_read_func).parameters.keys():
            if param in kwargs.keys():
                custom_kwargs[param] = kwargs[param]
        reader_instance = SimpleBuilder(lazy=lazy, read_func=path_or_read_func)
        return reader_instance.read(**custom_kwargs)
    else:
        try:
            reader_cls = import_main_class(path_or_read_func)
        except ModuleNotFoundError:
            # Not a local dataset module — fall back to the hub loader.
            datasets = load_from_hf(
                path_or_read_func, name=name, splits=splits, **kwargs)
        else:
            reader_instance = reader_cls(lazy=lazy, name=name, **kwargs)

            # Check if selected name and split is valid in this DatasetBuilder
            if hasattr(reader_instance, 'BUILDER_CONFIGS'):
                if name in reader_cls.BUILDER_CONFIGS.keys():
                    split_names = reader_cls.BUILDER_CONFIGS[name][
                        'splits'].keys()
                else:
                    raise ValueError(
                        'Invalid name "{}". Should be one of {}.'.format(
                            name, list(reader_cls.BUILDER_CONFIGS.keys())))
            elif hasattr(reader_instance, 'SPLITS'):
                split_names = reader_instance.SPLITS.keys()
            else:
                raise AttributeError(
                    "Either 'SPLITS' or 'BUILDER_CONFIGS' must be implemented for DatasetBuilder."
                )

            # Normalize `splits` so one name and a sequence of names are
            # validated identically.
            selected_splits = []
            if isinstance(splits, (list, tuple)):
                selected_splits.extend(splits)
            else:
                selected_splits += [splits]

            for split_name in selected_splits:
                if split_name not in split_names and split_name is not None:
                    raise ValueError('Invalid split "{}". Should be one of {}.'.
                                     format(split_name, list(split_names)))

            datasets = reader_instance.read_datasets(
                data_files=data_files, splits=splits)
        return datasets
- 在collate.py中添加Dict类
class Dict(object):
    """Batchify helper that collates a batch of dict examples by column.

    ``fn`` maps each input column name to the collate function applied to
    that column. Calling the instance with a batch (a list of dicts) applies
    each column's function to the list of that column's values; list/tuple
    results are spliced into the output, scalars are appended, and the
    collated columns are returned as a flat tuple.
    """

    def __init__(self, fn):
        assert isinstance(fn, dict), (
            'Input pattern not understood. The input of Dict must be a dict with key of input column name and value of collate_fn '
            'Received fn=%s' % (str(fn)))
        self._fn = fn
        for col_name, ele_fn in self._fn.items():
            # BUGFIX: the format spec was '%d' for col_name, but dict keys
            # are column *names* (normally str), so building the error
            # message itself raised TypeError. Use '%s'.
            assert callable(ele_fn), (
                'Batchify functions must be callable! type(fn[%s]) = %s' % (
                    col_name, str(type(ele_fn))))

    def __call__(self, data):
        """Collate `data` (a list of dict examples) into a flat tuple."""
        ret = []
        for col_name, ele_fn in self._fn.items():
            result = ele_fn([ele[col_name] for ele in data])
            if isinstance(result, (tuple, list)):
                # Multi-output collate fns contribute several columns.
                ret.extend(result)
            else:
                ret.append(result)
        return tuple(ret)
- 修改chnsenticrop.py,解决掉TSVDataset
import collections
import json
import os
from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url
from paddlenlp.utils.env import DATA_HOME
from . import DatasetBuilder
__all__ = ['ChnSentiCorp']
class ChnSentiCorp(DatasetBuilder):
    """
    ChnSentiCorp (by Tan Songbo at ICT of Chinese Academy of Sciences, and for
    opinion mining)
    """

    URL = "https://bj.bcebos.com/paddlenlp/datasets/ChnSentiCorp.zip"
    MD5 = "7ef61b08ad10fbddf2ba97613f071561"
    META_INFO = collections.namedtuple('META_INFO', ('file', 'md5'))
    SPLITS = {
        'train': META_INFO(
            os.path.join('ChnSentiCorp', 'ChnSentiCorp', 'train.tsv'),
            '689360c4a4a9ce8d8719ed500ae80907'),
        'dev': META_INFO(
            os.path.join('ChnSentiCorp', 'ChnSentiCorp', 'dev.tsv'),
            '20c77cc2371634731a367996b097ec0a'),
        'test': META_INFO(
            os.path.join('ChnSentiCorp', 'ChnSentiCorp', 'test.tsv'),
            '9b4dc7d1e4ada48c645b7e938592f49c'),
    }

    def _get_data(self, mode, **kwargs):
        """Downloads dataset."""
        root = os.path.join(DATA_HOME, self.__class__.__name__)
        filename, expected_md5 = self.SPLITS[mode]
        fullname = os.path.join(root, filename)
        # Fetch the archive when the split file is missing or its checksum
        # does not match the recorded one.
        missing = not os.path.exists(fullname)
        corrupt = (not missing and expected_md5 and
                   md5file(fullname) != expected_md5)
        if missing or corrupt:
            get_path_from_url(self.URL, root, self.MD5)
        return fullname

    def _read(self, filename, split):
        """Reads data: yields one example dict per TSV row; the column
        layout depends on `split` (test rows carry no label, train rows
        carry no qid)."""
        with open(filename, 'r', encoding='utf-8') as f:
            header_skipped = False
            for line in f:
                columns = line.strip().split("\t")
                if not header_skipped:
                    # First row is the TSV header.
                    header_skipped = True
                    continue
                if split == 'train':
                    label, text = columns
                    yield {"text": text, "label": label, "qid": ''}
                elif split == 'dev':
                    qid, label, text = columns
                    yield {"text": text, "label": label, "qid": qid}
                elif split == 'test':
                    qid, text = columns
                    yield {"text": text, "label": '', "qid": qid}

    def get_labels(self):
        """
        Return labels of the ChnSentiCorp object.
        """
        return ["0", "1"]
真是要疯掉了,底层调用的每个包都有问题,重装算了
气死 啦😤
- bug一个接一个,都出在库文件里,现在要上演移形换影大法,把paddle的文件全都替换掉!
paddle居然也用jieba分词器分词
# Tokenizer backed by jieba, restricted to the words found in `vocab`.
# NOTE(review): only __init__ is visible in this excerpt; the class likely
# defines more methods beyond this point.
class JiebaTokenizer(BaseTokenizer):
def __init__(self, vocab):
super(JiebaTokenizer, self).__init__(vocab)
self.tokenizer = jieba.Tokenizer()
# initialize tokenizer
# Seed jieba's word-frequency table with every vocab token at frequency 1.
self.tokenizer.FREQ = {key: 1 for key in self.vocab.token_to_idx.keys()}
self.tokenizer.total = len(self.tokenizer.FREQ)
# presumably marks the tokenizer as ready so jieba skips loading its
# default dictionary — TODO confirm against jieba internals.
self.tokenizer.initialized = True