这次做的任务牵涉到了BERT训练预料的下载。我们都知道BERT是在大规模无监督文本上训练的,那么它的训练数据从何而来呢 ? 从wiki上抓取文章来训练是一个很好的选择。
涉及的代码文件都被我上传至github了。 数据大的没传。
GitHub - xiaolilaoli/bert_data: how to get bert data
如果tf 报错,一般是一二代不兼容。用下面这句。
import tensorflow.compat.v1 as tf
一: wiki数据下载及处理,转tfrecord。
wiki数据的下载可以参考下面的博客。
比如我们选中下面这个文件。
下载后,需要使用WikiExtractor提取和整理数据集中的文本,使用步骤如下。
先安装:
- pip install wikiextractor
后在终端输入
- python -m wikiextractor.WikiExtractor -o 【目标文件路径】-b 【大小】 【源文件路径】
注意这个 大小 指的是提取出的文件大小,也就是如果写了1M, 输出时,凑够1M就输出一个文件。 如果写100M,那么每个输出文件都是100M
我写的1M,提取后如下图。
转换tfrecord
tfrecord是tensorflow的数据文件格式,转为tensorflow可以让读取更加容易。
首先我们要下载bert的词表,也就是vocab
之后去BERT官方代码库, 下载代码库放在本地。其实我们需要的是 tokenization.py和 create_pretraining_data.py 文件
找到 create_pretraining_data.py 这个代码文件,这个文件就是将wiki处理后的数据转换为BERT预训练 tfrecord文件的代码。
在服务器运行
python create_pretraining_data.py --input_file=/home/lsc/model/enwiki/AA/wiki_00
--output_file=/home/lsc/dataset/record/wiki_00.tfrecord --vocab_file=/home/lsc/model/uncased_L-12_H-768_A-12/vocab.txt --do_lower_case=True --max_seq_length=128 --max_predictions_per_seq=20 --masked_lm_prob=0.15 --random_seed=12345 --dupe_factor=5
或者在pycharm 中, 运行-编辑配置-形参中输入:
--input_file=/home/lsc/model/enwiki/AA/wiki_00
--output_file=/home/lsc/dataset/record/wiki_00.tfrecord --vocab_file=/home/lsc/model/uncased_L-12_H-768_A-12/vocab.txt --do_lower_case=True --max_seq_length=128 --max_predictions_per_seq=20 --masked_lm_prob=0.15 --random_seed=12345 --dupe_factor=5
记得输入,输出,及词表地址 改成自己的文件地址。运行即可得到后续训练需要的tfrecord文件。
二 glue数据下载及转换tfrecord。以QNLI为例。
如果你做nlp,glue数据集你一定很熟悉,我就不过多介绍了。我们看glue数据集是如何下载的。
glue有官方的下载代码。
有了代码后,你可以
在main函数这里添加args.tasks datadir。
或者直接运行时输入参数 --tasks='QNLI' --data_dir=''
数据目录就是你存放数据的地址。
以qnli为例, 下载后得到三个文件。下一步转为tfrecord数据
转tfrecord
转换的数据依然来源于BERT官方代码库 ,参见readme使用 run_classifier.py 代码文件。run_classifier.py
代码中包含了训练,推理和预测的代码,对于转化tfrecord数据集来说,这部分代码是多余的,可以将这部分代码注释掉,只保留转化数据集的代码.经过我探究,应该主要保留下面的代码即可。 我将修改数据格式部分的代码 重命名为 create_task_data.py
运行时输入 :
python create_task_data.py --task_name=QNLI --data_dir=/home/lsc/dataset/glue/QNLI --vocab_file=/home/lsc/model/uncased_L-12_H-768_A-12/vocab.txt --max_seq_length=128 --output_dir=/home/lsc/model/glue
注意地址。
此外,要加入QNLI的处理代码,原来的代码库中没有。如下图,可以看到 数据处理器中,只有四个(qnli是我自己加的),你如果想处理其他数据集,就搜下其他数据集的处理代码。将qnli处理器类放入代码文件,将名字加入列表。
...
class QnliProcessor(DataProcessor):
"""Processor for the QNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")),
"dev_matched")
def get_labels(self):
"""See base class."""
return ["entailment", "not_entailment"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
text_a = line[1]
text_b = line[2]
label = line[-1]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
...
"qnli": QnliProcessor,
...
即可得到输出的tfrecord数据