tensorflow 小于_基于tensorflow的文本分类

该博客介绍了使用TensorFlow进行文本分类的实践,包括数据预处理、模型选择(如bilstm、bilstm+attention、textcnn、rcnn、transformer)、配置参数(如词嵌入维度、学习率、epoch等)以及训练和测试过程。提供了一个基于复旦中文语料的20类文本分类任务,并分享了数据集、处理后的数据和预训练模型的链接。
摘要由CSDN通过智能技术生成

tensorflow-text-classification

数据集:复旦中文语料,包含20类
数据集下载地址:https://www.kesci.com/mw/dataset/5d3a9c86cf76a600360edd04/content
数据集下载好之后将其放置在data文件夹下;
修改globalConfig.py中的全局路径为自己项目的路径;
处理后的数据和已训练好保存的模型,在这里可以下载:
链接:https://pan.baidu.com/s/1ZHzO5e__-WFYAYFIt2Kmsg 提取码:vvzy

目录结构:
|--checkpint:保存模型目录
|--|--transformer:transformer模型保存位置;
|--config:配置文件;
|--|--fudanConfig.py:包含训练配置、模型配置、数据集配置;
|--|--globaConfig.py:全局配置文件,主要是全局路径、全局参数等;
|-- data:数据保存位置;
|--|--|--Fudan:复旦数据;
|--|--|--train:训练数据;
|--|--|--answer:测试数据;
|--dataset:创建数据集,对数据进行处理的一些操作;
|--images:结果可视化图片保存位置;
|--models:模型保存文件;
|--process:对原始数据进行处理后的数据;
|--tensorboard:tensorboard可视化文件保存位置,暂时未用到;
|--utils:辅助函数保存位置,包括word2vec训练词向量、评价指标计算、结果可视化等;
|--main.py:主运行文件,选择模型、训练、测试和预测;

初始配置:

  • 词嵌入维度:200

  • 学习率:0.001

  • epoch:50

  • 词汇表大小:6000+2(加2是PAD和UNK)

  • 文本最大长度:600

  • 每多少个step进行验证:100

  • 每多少个step进行存储模型:100

环境:

  • python=>=3.6

  • tensorflow==1.15.0

当前支持的模型:

  • bilstm

  • bilstm+attention

  • textcnn

  • rcnn

  • transformer

说明

数据的输入格式:
(1)分词后去除掉停止词,再对词语进行词频统计,取频数最高的前6000个词语作为词汇表;
(2)像词汇表中加入PAD和UNK,实际上的词汇表的词语总数为6000+2=6002;
(3)当句子长度大于指定的最大长度,进行裁剪,小于最大长度,在句子前面用PAD进行填充;
(4)如果句子中的词语在词汇表中没有出现则用UNK进行代替;
(5)输入到网络中的句子实际上是进行分词后的词语映射的id,比如:
(6)输入的标签是要经过onehot编码的;
"""
"我喜欢上海",
"我喜欢打羽毛球",
"""
词汇表:['我','喜欢','打','上海','羽毛球'],对应映射:[2,3,4,5,6],0对应PAD,1对应UNK
得到:
[
[0,2,3,5],
[2,3,4,6],
]

python main.py --model transformer --saver_dir checkpoint/transformer --save_png images/transformer  --train  --test  --predict 

参数说明:

  • --model:选择模型,可选[transformer、bilstm、bilstmattn、textcnn、rcnn]

  • --saver_dir:模型保存位置,一般是checkpoint+模型名称

  • --save_png:结果可视化保存位置,一般是images+模型名称

  • --train:是否进行训练,默认为False

  • --test:是否进行测试,默认为False

  • --predict:是否进行预测,默认为False

结果

以transformer为例:
部分训练结果:2020-11-01T10:43:16.955322, step: 1300, loss: 5.089711, acc: 0.8546,precision: 0.3990, recall: 0.4061, f_beta: 0.3977 *Epoch: 83train: step: 1320, loss: 0.023474, acc: 0.9922, recall: 0.8444, precision: 0.8474, f_beta: 0.8457Epoch: 84train: step: 1340, loss: 0.000000, acc: 1.0000, recall: 0.7500, precision: 0.7500, f_beta: 0.7500Epoch: 85train: step: 1360, loss: 0.000000, acc: 1.0000, recall: 0.5500, precision: 0.5500, f_beta: 0.5500Epoch: 86Epoch: 87train: step: 1380, loss: 0.000000, acc: 1.0000, recall: 0.7500, precision: 0.7500, f_beta: 0.7500Epoch: 88train: step: 1400, loss: 0.000000, acc: 1.0000, recall: 0.7000, precision: 0.7000, f_beta: 0.7000
开始验证。。。2020-11-01T10:44:07.347359, step: 1400, loss: 5.111372, acc: 0.8506,precision: 0.4032, recall: 0.4083, f_beta: 0.3982 *Epoch: 89train: step: 1420, loss: 0.000000, acc: 1.0000, recall: 0.5500, precision: 0.5500, f_beta: 0.5500Epoch: 90train: step: 1440, loss: 0.000000, acc: 1.0000, recall: 0.5500, precision: 0.5500, f_beta: 0.5500Epoch: 91Epoch: 92train: step: 1460, loss: 0.000000, acc: 1.0000, recall: 0.7000, precision: 0.7000, f_beta: 0.7000Epoch: 93train: step: 1480, loss: 0.000000, acc: 1.0000, recall: 0.7500, precision: 0.7500, f_beta: 0.7500Epoch: 94train: step: 1500, loss: 0.000000, acc: 1.0000, recall: 0.6000, precision: 0.6000, f_beta: 0.6000
开始验证。。。2020-11-01T10:44:57.645305, step: 1500, loss: 5.206666, acc: 0.8521,precision: 0.4003, recall: 0.4040, f_beta: 0.3957 Epoch: 95train: step: 1520, loss: 0.000000, acc: 1.0000, recall: 0.6000, precision: 0.6000, f_beta: 0.6000Epoch: 96Epoch: 97train: step: 1540, loss: 0.000000, acc: 1.0000, recall: 0.7500, precision: 0.7500, f_beta: 0.7500Epoch: 98train: step: 1560, loss: 0.000000, acc: 1.0000, recall: 0.7000, precision: 0.7000, f_beta: 0.7000Epoch: 99train: step: 1580, loss: 0.000000, acc: 1.0000, recall: 0.8000, precision: 0.8000, f_beta: 0.8000Epoch: 100train: step: 1600, loss: 0.000000, acc: 1.0000, recall: 0.5000, precision: 0.5000, f_beta: 0.5000
开始验证。。。2020-11-01T10:45:47.867190, step: 1600, loss: 5.080955, acc: 0.8566,precision: 0.4087, recall: 0.4131, f_beta: 0.4036 *<Figure size 1000x600 with 10 Axes>
绘图完成了。。。
开始进行测试。。。
计算Precision, Recall and F1-Score...precision recall f1-score supportAgriculture 0.89 0.90 0.89 1022Art 0.80 0.95 0.86 742Communication 0.19 0.26 0.22 27Computer 0.95 0.94 0.94 1358Economy 0.86 0.91 0.89 1601Education 1.00 0.11 0.21 61Electronics 0.35 0.39 0.37 28Energy 1.00 0.03 0.06 33Enviornment 0.88 0.96 0.92 1218History 0.79 0.48 0.60 468Law 1.00 0.12 0.21 52Literature 0.00 0.00 0.00 34Medical 0.50 0.13 0.21 53Military 0.33 0.01 0.03 76Mine 1.00 0.03 0.06 34Philosophy 1.00 0.04 0.09 45Politics 0.73 0.91 0.81 1026Space 0.84 0.86 0.85 642Sports 0.93 0.91 0.92 1254Transport 0.33 0.03 0.06 59accuracy 0.86 9833macro avg 0.72 0.45 0.46 9833weighted avg 0.85 0.86 0.84 9833

结果可视化图片如下:

f5b38b71ed01bc93ccb0d3064c16fcb2.png

进行预测。。。

开始预测文本的类别。。。
输入的文本是:自动化学报ACTA AUTOMATICA SINICA1997年 第23卷 第4期 Vol.23 No.4 1997一种在线建模方法的研究1)赵希男 粱三龙 潘德惠摘 要 针对一类系统提出了一种通用性...
预测的类别是:Computer
真实的类别是:Computer================================================
输入的文本是:航空动力学报JOURNAL OF AEROSPACE POWER1999年 第14卷 第1期 VOL.14 No.1 1999变几何涡扇发动机几何调节对性能的影响朱之丽 李 东摘要:本文以高推重比涡扇...
预测的类别是:Space
真实的类别是:Space================================================
输入的文本是:【 文献号 】1-4242【原文出处】图书馆论坛【原刊地名】广州【原刊期号】199503【原刊页号】13-15【分 类 号】G9【分 类 名】图书馆学、信息科学、资料工作【 作 者 】周坚宇【复印期...
预测的类别是:Sports
真实的类别是:Sports================================================
输入的文本是:产业与环境INDUSTRY AND ENVIRONMENT1998年 第20卷 第4期 Vol.20 No.4 1998科技期刊采矿——事实与数字引 言本期《产业与环境》中的向前看文章并没有十分详细地...
预测的类别是:Enviornment
真实的类别是:Enviornment================================================
输入的文本是:环境技术ENVIRONMENTAL TECHNOLOGY1999年 第3期 No.3 1999正弦振动试验中物理计算闫立摘要:本文通过阐述正弦振动试验技术涉及的物理概念、力学原理,编写了较适用的C语言...
预测的类别是:Space
真实的类别是:Enviornment================================================

下面是一些实现的对比:

transformer

评价指标precisionrecallf1-scoresupport
accuracy0.869833
macro avg0.720.450.469833
weighted avg0.850.860.849833

bistm

评价指标precisionrecallf1-scoresupport
accuracy0.779833
macro avg0.470.400.419833
weighted avg0.760.770.769833

bilstmattn

评价指标precisionrecallf1-scoresupport
accuracy0.929833
macro avg0.700.640.659833
weighted avg0.930.920.929833

textrcnn

评价指标precisionrecallf1-scoresupport
accuracy0.899833
macro avg0.710.460.489833
weighted avg0.880.890.879833

rcnn

 很奇怪,rcnn网络并没有得到有效的训练

评价指标precisionrecallf1-scoresupport
accuracy0.169833
macro avg0.010.050.029833
weighted avg0.040.160.059833

十分感谢以下仓库,给了自己很多参考:
https://github.com/jiangxinyang227/NLP-Project/tree/master/text_classifier 
https://github.com/gaussic/text-classification-cnn-rnn

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 文本分类 #### 数据预处理 要求训练集和测试集分开存储,对于中文的数据必须先分词,对分词后的词用空格符分开,并且将标签连接到每条数据的尾部,标签和句子用分隔符\分开。具体的如下: * 今天 的 天气 真好\积极 #### 文件结构介绍 * config文件:配置各种模型的配置参数 * data:存放训练集和测试集 * ckpt_model:存放checkpoint模型文件 * data_helpers:提供数据处理的方法 * pb_model:存放pb模型文件 * outputs:存放vocab,word_to_index, label_to_index, 处理后的数据 * models:存放模型代码 * trainers:存放训练代码 * predictors:存放预测代码 #### 训练模型 * python train.py --config_path="config/textcnn_config.json" #### 预测模型 * 预测代码都在predictors/predict.py中,初始化Predictor对象,调用predict方法即可。 #### 模型的配置参数详述 ##### textcnn:基于textcnn的文本分类 * model_name:模型名称 * epochs:全样本迭代次数 * checkpoint_every:迭代多少步保存一次模型文件 * eval_every:迭代多少步验证一次模型 * learning_rate:学习速率 * optimization:优化算法 * embedding_size:embedding层大小 * num_filters:卷积核的数量 * filter_sizes:卷积核的尺寸 * batch_size:批样本大小 * sequence_length:序列长度 * vocab_size:词汇表大小 * num_classes:样本的类别数,二分类时置为1,多分类时置为实际类别数 * keep_prob:保留神经元的比例 * l2_reg_lambda:L2正则化的系数,主要对全连接层的参数正则化 * max_grad_norm:梯度阶段临界值 * train_data:训练数据的存储路径 * eval_data:验证数据的存储路径 * stop_word:停用词表的存储路径 * output_path:输出路径,用来存储vocab,处理后的训练数据,验证数据 * word_vectors_path:词向量的路径 * ckpt_model_path:checkpoint 模型的存储路径 * pb_model_path:pb 模型的存储路径 ##### bilstm:基于bilstm的文本分类 * model_name:模型名称 * epochs:全样本迭代次数 * checkpoint_every:迭代多少步保存一次模型文件 * eval_every:迭代多少步验证一次模型 * learning_rate:学习速率 * optimization:优化算法 * embedding_size:embedding层大小 * hidden_sizes:lstm的隐层大小,列表对象,支持多层lstm,只要在列表中添加相应的层对应的隐层大小 * batch_size:批样本大小 * sequence_length:序列长度 * vocab_size:词汇表大小 * num_classes:样本的类别数,二分类时置为1,多分类时置为实际类别数 * keep_prob:保留神经元的比例 * l2_reg_lambda:L2正则化的系数,主要对全连接层的参数正则化 * max_grad_norm:梯度阶段临界值 * train_data:训练数据的存储路径 * eval_data:验证数据的存储路径 * stop_word:停用词表的存储路径 * output_path:输出路径,用来存储vocab,处理后的训练数据,验证数据 * word_vectors_path:词向量的路径 * ckpt_model_path:checkpoint 模型的存储路径 * pb_model_path:pb 模型的存储路径 ##### bilstm atten:基于bilstm + attention 的文本分类 * model_name:模型名称 * epochs:全样本迭代次数 * checkpoint_every:迭代多少步保存一次模型文件 * eval_every:迭代多少步验证一次模型 * learning_rate:学习速率 * optimization:优化算法 * embedding_size:embedding层大小 * hidd
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值