本文介绍一个超级简单的中文文本分类模型,模型的代码、数据集已整理好放于文末获取。通过该代码不仅可以实现文本分类,而且可以实现情感分析,情感检测等。
项目介绍
在本案例中,我们将训练一个中文文本分类模型。所用到的数据集根据新浪新闻rss订阅频道的历史数据筛选生成;数据集包含10个分类:‘体育’, ‘财经’, ‘房产’, ‘家居’, ‘教育’, ‘科技’, ‘时尚’, ‘时政’, ‘游戏’, ‘娱乐’。
文件总览:
代码介绍:
- data目录:存放用于文本分类模型用于训练和测试的数据。
- checkpoint目录:存放训练之后的权重文件;
- train.py:模型训练代码;
- predict.py:模型测试代码;
- preprocess.py:data数据集的预处理代码;
- model.py:文本分类模型代码;
- config.py:项目的参数文件;
- requirements.txt:记录项目所需python依赖包;
模型训练
1. 环境安装
此处默认安装好了python>=3.7版本的运行环境,本案例代码主要用到的python依赖库有:jieba、scikit_learn、tensorflow,我们可以在cmd命令行运行如下命令安装本案例代码需要的依赖。
pip install -r requirements.txt
2. config文件参数设置
在config文件中我们已经定义好了模型的参数,数据集的路径,模型权重的输出路径等。当然,如果我们想要训练自己的文本分类模型,可以通过修改训练数据路径来实现训练我们自己的文本分类模型。
3. 运行代码
trian.py文件的内容如下所示,模型的接口我们都封装好了,我们可以一键启动训练我们的模型,其中 cnn_model.train(3) 代表训练三个epoch。
from model import TextCNN
if __name__ == '__main__':
CNN_model = TextCNN()
CNN_model.train(3)
CNN_model.test()
训练过程:
模型使用
1. config文件参数设置
此处我们使用config.py文件默认的参数即可,不需要做修改。
2. 运行代码
predict.py文件的内容如下所示,模型的接口我们都封装好了,我们只需要指定需要进行预测的文本即可。下面我们使用一段关于NBA的赛事介绍文章,使用模型来预测该篇文章类型。
from model import TextCNN
if __name__ == '__main__':
CNN_model = TextCNN()
test_sentence = r"黄蜂vs湖人首发:科比冲击七连胜 火箭两旧将登场新浪体育讯北京时间3月28日,NBA常规赛洛杉矶湖人主场迎战新奥尔良黄蜂,赛前双方也公布了首发阵容:点击进入新浪体育视频直播室点击进入新浪体育图文直播室点击进入新浪体育NBA专题点击进入新浪NBA官方微博双方首发阵容:湖人队:德里克-费舍尔、科比-布莱恩特、罗恩-阿泰斯特、保罗-加索尔、安德鲁-拜纳姆黄蜂队:克里斯-保罗、马科-贝里内利、特雷沃-阿里扎、卡尔-兰德里、埃梅卡-奥卡福(新浪体育)"
result = CNN_model.predict(test_sentence)
print("预测结果:{}".format(result))
测试结果:
获取资料
项目代码经过严格调试,在windows下运行无bug。支付宝扫描下方图片链接获取资料的百度网盘下载地址。