一、长文本分类
参考项目来源:https://github.com/CLOVEXCWZ/Pytorch_LongText_Classification_Demo
1、数据集介绍
数据集使用的是搜狗新闻数据语料,下载链接:http://thuctc.thunlp.org/ ,新闻语料中主要有包含多个类别,由于考虑到样本量和样本均衡情况,该任务只选取前4种样本量比较多且较为均衡的类别作为项目的数据集,可通过链接:处理后的数据集百度云盘下载地址 提取码:mbt3
下载处理好的数据集的训练集sougou_train.txt和验证集sougou_dev.txt,下载后放在Pytorch_LongText_Classification_Demo-master\datas\SougouNews目录下。
数据集每一行的格式为: 文本+Tab+标签
数据集情况如下:
类别 | 数量 | 对应index |
sports | 85984 | 0 |
news | 82740 | 1 |
house | 71221 | 2 |
business | 61843 | 3 |
2、数据预处理
可运行文件dataprocess.py中的代码
train_char = load_dataset(model='train', leve='char', max_setence=10, max_words=50)
进行数据预处理生成词典,处理逻辑是对长文本按句子分割,再对句子按词或字分割,选取前max_setence的前max_words的字或词,然后统计出现的字或词的频率,按频率从高到低排序,选取排名前50000的字或词,生成pkl文件用于保存这些词或字的索引,最后将训练文本按词或字转成成对应的索引向量。
model='train'表示对训练数据做处理,'dev'表示对验证数据做处理;leve='char'表示对文本按字分割,'word'表示按词分割,生成的对应pkl文件为char_vocab.pkl和word_vocab.pkl
若运行过程中,以下代码出现bug
samples_b = np.array(samples_b)
改成这样即可
samples_b = np.array(samples_b, dtype=object)
3、模型训练
运行文件train_all.py可训练模型
用到的模型有fasttext、textcnn、textrcnn、textrnn和transformer
参数leve = 'char'表示训练字级别的模型,'word'表示训练词级别
模型结构介绍可参考:https://zhuanlan.zhihu.com/p/73176084
Transformer模型的介绍可参考:
训练结果:
网络 | 字级别(准确率) | 词级别(准确率) |
FastText | 0.9410 | 0.9650 |
TextCNN | 0.9662 | 0.9693 |
TextRNN | 0.9661 | 0.9702 |
TextRCNN | 0.9705 | 0.9712 |
Transformer | 0.9644 | 0.9684 |
二、短文本分类
短文本分类可参考链接:https://github.com/CLOVEXCWZ/Pytorch_Text_Classification_Demo
模型效果:
网络模型 | 准确率 |
FastText | 85.34% |
TextCNN | 89.62% |
TextRNN | 88.9% |
TextRCNN | 90.22% |
Transformer | 88.98% |
也可参考链接:https://github.com/649453932/Chinese-Text-Classification-Pytorch?tab=readme-ov-file ,使用的模型更多
模型效果:
模型 | 准确率 | 备注 |
TextCNN | 91.22% | Kim 2014 经典的CNN文本分类 |
TextRNN | 91.12% | BiLSTM |
TextRNN_Att | 90.90% | BiLSTM+Attention |
TextRCNN | 91.54% | BiLSTM+池化 |
FastText | 92.23% | bow+bigram+trigram, 效果出奇的好 |
DPCNN | 91.25% | 深层金字塔CNN |
Transformer | 89.91% | 效果较差 |
bert | 94.83% | bert + fc |
ERNIE | 94.61% | 比bert略差(说好的中文碾压bert呢) |