Paddle2.0实现中文新闻文本标题分类
本项目仅用于参考,提供思路和想法并非标准答案!请谨慎抄袭!
中文新闻文本标题分类Paddle2.0版本基线(非官方)
非官方,三岁出品!(虽水必精)
调优小建议
本项目基线的值不会很高,需要自行调参来提高效果。
优化建议:
- 修改模型 现在是线性模型可以尝试修改更为复杂的
对于nlp项目更加友好的(具体的我也不是很清楚) - 调整学习率来调整我们最好效果的查找
- 可以通过对已有模型进一步训练得到较好的效果
- ……
数据集地址
https://aistudio.baidu.com/aistudio/datasetdetail/75812
任务描述
基于THUCNews数据集的文本分类, THUCNews是根据新浪新闻RSS订阅频道2005~2011年间的历史数据筛选过滤生成,包含74万篇新闻文档,参赛者需要根据新闻标题的内容用算法来判断该新闻属于哪一类别
数据说明
THUCNews是根据新浪新闻RSS订阅频道2005~2011年间的历史数据筛选过滤生成,包含74万篇新闻文档(2.19 GB),均为UTF-8纯文本格式。在原始新浪新闻分类体系的基础上,重新整合划分出14个候选分类类别:财经、彩票、房产、股票、家居、教育、科技、社会、时尚、时政、体育、星座、游戏、娱乐。
已将训练集按照“标签ID+\t+标签+\t+原文标题”的格式抽取出来,可以直接根据新闻标题进行文本分类任务,希望答题者能够给出自己的解决方案。
训练集格式 标签ID+\t+标签+\t+原文标题 测试集格式 原文标题
提交答案
考试提交,需要提交模型代码项目版本和结果文件。结果文件为TXT文件格式,命名为result.txt,文件内的字段需要按照指定格式写入。
1.每个类别的行数和测试集原始数据行数应一一对应,不可乱序
2.输出结果应检查是否为83599行数据,否则成绩无效
3.输出结果文件命名为result.txt,一行一个类别,样例如下:
···
游戏
财经
时政
股票
家居
科技
社会
房产
教育
星座
科技
股票
游戏
财经
时政
股票
家居
科技
社会
房产
教育
···
代码思路说明
根据题目可以知道这个是一个经典的nlp
任务。
根据nlp
任务处理的一般流程,我们需要进行以下几个步骤:
- 数据处理并转换成词向量
- 模型的搭建
- 数据的训练
- 模型读取并推理数据得到结果
那么话不多说我们开始!
数据集解压
! pip install -U paddlepaddle==2.0.1
! unzip -oq /home/aistudio/data/data75812/新闻文本标签分类.zip
import paddle
import numpy as np
import matplotlib.pyplot as plt
import paddle.nn as nn
import os
import numpy as np
print(paddle.__version__) # 查看当前版本
# cpu/gpu环境选择,在 paddle.set_device() 输入对应运行设备。
# device = paddle.set_device('gpu')
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import Sized
2021-03-27 12:21:25,020 - INFO - font search path ['/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']
2021-03-27 12:21:25,357 - INFO - generated new fontManager
2.0.1
数据处理
首先我们考虑词向量的书写方式。
我们先制作词典(此处词典已经制作完成,我们直接读取就好了,词典制作过程会放在留言中)
我们把词典和我们的数据集进行对应,制作完成一个纯数字的对应码
得到对应码以后进行输出测试是否正确。
数据无误进行填充,把数据码用特殊标签进行替代完成数据长度相同的内容
检验数据长度
数据读取(字典、数据集)
# 字典读取
def get_dict_len(d_path):
with open(d_path, 'r', encoding='utf-8') as f:
line = eval(f.readlines()[0])
return line
word_dict = get_dict_len('新闻文本标签分类/dict.txt')
# 训练集和验证集读取
set = []
def dataset(datapath): # 数据集读取代码
with open(datapath)as f:
for i in f.readlines():
data = []
dataset = i[:i.rfind('\t')].split(',') # 获取文字内容
dataset = np.array(dataset)
data.append(dataset)
label = np.array(i[i.rfind('\t')+1:-1]) # 获取标签
data.append(label)
set.append(data)
return set
train_dataset = dataset('新闻文本标签分类/Train_IDs.txt')
val_dataset = dataset('新闻文本标签分类/Val_IDs.txt')
数据初始化
定义一些需要的值
# 初始数据准备
vocab_size = len(word_dict) + 1 # 字典长度加1
print(vocab_size)
emb_size = 256 # 神经网络长度
seq_len = 30 # 数据集长度(需要扩充的长度)
batch_size = 32 # 批处理大小
epochs = 2 # 训练轮数
pad_id = word_dict['<unk>'] # 空的填充内容值
nu=["财经","彩票","房产","股票","家居","教育","科技","社会","时尚","时政","体育","星座","游戏","娱乐"]
# 生成句子列表(数据码生成文本)
def ids_to_str(ids):
# print(ids)
words = []
for k in ids:
w = list(word_dict)[eval(k)]
words.append(w if isinstance(w, str) else w.decode('ASCII'))
return " ".join(words)
5308
数据查看
查看数据是否正确如有异常及时修改
# 查看数据内容
for i in train_dataset:
sent = i[0]
label = int(i[1