本文档介绍了如何使用BERT实现多类别文本分类任务,适合稍微了解BERT和文本分类的同学参考。
(一) 下载
首先,在github上clone谷歌的BERT项目,或者直接下载。项目地址
然后,下载中文预训练模型,地址
(二) 环境准备
tensorflow >= 1.11.0
注意:
- 在GPU上运行Tensorflow,需要CUDA版本和Tensorflow版本的对应。比如Tensorflow-1.11.0最高只能使用9.0版本的CUDA,否则加载时会出现找不到libcublas.so的错误。
- 安装TensorFlow时,如果出现无法卸载enum34的错误,可以用pip install *** --ignore_installed enum34命令先跳过。
(三) 数据准备
准备数据集,包括训练集、验证集、测试集,格式相同,每行为一个类别+文本,用“\t”间隔。(如果选择其他间隔符,需要修改run_classifier.py中_read_tsv方法)。
我做的是新闻文本分类,数据格式如下:
(四) 修改run_classifier.py文件
- 添加处理数据集的类,class ZbsProcessor(DataProcessor),分别实现以下方法:
def get_train_examples(self, data_dir): 读取训练集
def get_dev_examples(self, data_dir): 读取验证集
def get_test_examples(self, data_dir): 读取测试集
def get_labels(self, labels): 获得类别集合
def _create_examples(self, lines, set_type): 生成训练和验证样本
- 修改main函数。在第744行,将ZbsProcessor添加到processors中
processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mrpc": MrpcProcessor,
"xnli": XnliProcessor,
"zbs": ZbsProcessor
}
- 原代码中,先判断是否train,然后获取训练样本,但是后面需要所有类别,所以需要改成先获取所有类别,然后判断判断是否train。即代码:
if FLAGS.do_train:
train_examples = processor.get_train_examples