Paddle2.0实现中文新闻文本标题分类

项目说明,本项目是李宏毅老师在飞桨授权课程的作业解析
课程 传送门
该项目AiStudio项目 传送门
数据集 传送门

本项目仅用于参考,提供思路和想法并非标准答案!请谨慎抄袭!

中文新闻文本标题分类Paddle2.0版本基线(非官方)

非官方,三岁出品!(虽水必精)

调优小建议

本项目基线的值不会很高,需要自行调参来提高效果。
优化建议:

  • 修改模型 现在是线性模型可以尝试修改更为复杂的
    对于nlp项目更加友好的(具体的我也不是很清楚)
  • 调整学习率来调整我们最好效果的查找
  • 可以通过对已有模型进一步训练得到较好的效果
  • ……

数据集地址

https://aistudio.baidu.com/aistudio/datasetdetail/75812

任务描述

基于THUCNews数据集的文本分类, THUCNews是根据新浪新闻RSS订阅频道2005~2011年间的历史数据筛选过滤生成,包含74万篇新闻文档,参赛者需要根据新闻标题的内容用算法来判断该新闻属于哪一类别

数据说明

THUCNews是根据新浪新闻RSS订阅频道2005~2011年间的历史数据筛选过滤生成,包含74万篇新闻文档(2.19 GB),均为UTF-8纯文本格式。在原始新浪新闻分类体系的基础上,重新整合划分出14个候选分类类别:财经、彩票、房产、股票、家居、教育、科技、社会、时尚、时政、体育、星座、游戏、娱乐。

已将训练集按照“标签ID+\t+标签+\t+原文标题”的格式抽取出来,可以直接根据新闻标题进行文本分类任务,希望答题者能够给出自己的解决方案。

训练集格式 标签ID+\t+标签+\t+原文标题 测试集格式 原文标题

提交答案

考试提交,需要提交模型代码项目版本结果文件。结果文件为TXT文件格式,命名为result.txt,文件内的字段需要按照指定格式写入。

1.每个类别的行数和测试集原始数据行数应一一对应,不可乱序

2.输出结果应检查是否为83599行数据,否则成绩无效

3.输出结果文件命名为result.txt,一行一个类别,样例如下:

···

游戏

财经

时政

股票

家居

科技

社会

房产

教育

星座

科技

股票

游戏

财经

时政

股票

家居

科技

社会

房产

教育

···

代码思路说明

根据题目可以知道这个是一个经典的nlp任务。
根据nlp任务处理的一般流程,我们需要进行以下几个步骤:

  • 数据处理并转换成词向量
  • 模型的搭建
  • 数据的训练
  • 模型读取并推理数据得到结果

那么话不多说我们开始!

数据集解压

! pip install -U paddlepaddle==2.0.1
! unzip -oq /home/aistudio/data/data75812/新闻文本标签分类.zip
import paddle
import numpy as np
import matplotlib.pyplot as plt
import paddle.nn as nn
import os
import numpy as np

print(paddle.__version__)  # 查看当前版本

# cpu/gpu环境选择,在 paddle.set_device() 输入对应运行设备。
# device = paddle.set_device('gpu')
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized
2021-03-27 12:21:25,020 - INFO - font search path ['/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']
2021-03-27 12:21:25,357 - INFO - generated new fontManager


2.0.1

数据处理

首先我们考虑词向量的书写方式。
我们先制作词典(此处词典已经制作完成,我们直接读取就好了,词典制作过程会放在留言中)
我们把词典和我们的数据集进行对应,制作完成一个纯数字的对应码
得到对应码以后进行输出测试是否正确。
数据无误进行填充,把数据码用特殊标签进行替代完成数据长度相同的内容
检验数据长度

数据读取(字典、数据集)

# 字典读取
def get_dict_len(d_path):
    with open(d_path, 'r', encoding='utf-8') as f:
        line = eval(f.readlines()[0])
    return line

word_dict = get_dict_len('新闻文本标签分类/dict.txt')
# 训练集和验证集读取
set = []
def dataset(datapath):  # 数据集读取代码
    with open(datapath)as f:
        for i in f.readlines():
            data = []
            dataset = i[:i.rfind('\t')].split(',')  # 获取文字内容
            dataset = np.array(dataset)
            data.append(dataset)
            label = np.array(i[i.rfind('\t')+1:-1])  # 获取标签
            data.append(label)
            set.append(data)
    return set

train_dataset = dataset('新闻文本标签分类/Train_IDs.txt')
val_dataset = dataset('新闻文本标签分类/Val_IDs.txt')

数据初始化


定义一些需要的值

# 初始数据准备 
vocab_size = len(word_dict) + 1  # 字典长度加1
print(vocab_size)
emb_size = 256  # 神经网络长度
seq_len = 30  # 数据集长度(需要扩充的长度)
batch_size = 32  # 批处理大小
epochs = 2  # 训练轮数
pad_id = word_dict['<unk>']  # 空的填充内容值

nu=["财经","彩票","房产","股票","家居","教育","科技","社会","时尚","时政","体育","星座","游戏","娱乐"]

# 生成句子列表(数据码生成文本)
def ids_to_str(ids):
    # print(ids)
    words = []
    for k in ids:
        w = list(word_dict)[eval(k)]
        words.append(w if isinstance(w, str) else w.decode('ASCII'))
    return " ".join(words)
5308

数据查看


查看数据是否正确如有异常及时修改

# 查看数据内容
for i in  train_dataset:
    sent = i[0]
    label = int(i[1
  • 3
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 7
    评论
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

三岁学编程

感谢支持,更好的作品会继续努力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值