基于PaddleNLP的文本多分类任务

此项目是基于PaddleNLP的文本多分类任务,包含外交、军事、经济、政治、科技、安全6大领域。

1、代码结构

multi_class/
├── few-shot # 小样本学习方案
├── retrieval_based # 语义索引方案
├── analysis # 分析模块
├── deploy # 部署
│   ├── simple_serving # SimpleServing服务化部署
│   └── triton_serving # Triton服务化部署
├── train.py # 训练、评估、裁剪脚本
├── utils.py # 工具函数脚本
└── README.md # 多分类使用说明

2、环境构建

在这里插入图片描述

2.1、环境安装

conda create -n UTC_multi_class python=3.7 pip=21.1.1
conda activate UTC_multi_class
python3 -m  pip install scikit-learn==1.0.2
pip install paddlenlp==2.5.1
python -m pip install paddlepaddle-gpu==2.3.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html

3、数据准备

训练需要准备指定格式的本地数据集,如果没有已标注的数据集,可以参考文本分类任务doccano数据标注使用指南(https://github.com/PaddlePaddle/PaddleNLP/blob/develop/applications/text_classification/doccano.md)进行文本分类数据标注。指定格式本地数据集目录结构:

data/
├── train.txt # 训练数据集文件
├── dev.txt # 开发数据集文件
└── label.txt # 分类标签文件

训练、开发、测试数据集 文件中文本与标签类别名用tab符’\t’分隔开,文本中避免出现tab符’\t’。

  • train.txt/dev.txt/test.txt 文件格式:
<文本>'\t'<标签>
<文本>'\t'<标签>
...	
  • train.txt/dev.txt/test.txt 文件样例:
......
中新网4月29日电据法新社报道,巴基斯坦军方29日表示,巴基斯坦当天成功试射了一枚可以装载核弹头的弹道导弹,其射程达到2000公里。据悉,此次试射的是哈塔夫-6型地对地弹道导弹(沙欣2号),2005年3月巴基斯坦曾经试射过一次该型导弹。巴基斯坦军方的声明表示,第二次试射的哈塔夫-6型导弹比去年要增加了更多的技术含量。哈塔夫-6型导弹依靠二级固体燃料火箭运载,可以装载常规弹头和核弹头,并具有高打击精度。巴基斯坦总理阿齐兹参加观看了导弹试射的全过程,但是发射地点暂时没有对外公布。试射成功后,阿齐兹向参与导弹发射工作的科学家、工程师和工作人员表示了祝贺,并表示巴基斯坦的战略规划将“从强大走向更加强大”。哈塔夫-6型导弹是巴基斯坦射程最远的弹道导弹,而且具有发展到2500公里射程的潜力。(章田/雅龙)    军事
英国简氏防务网站2006年3月13日报道沙特阿拉伯装甲部队计划于2006年4月对其MBT-2000型“阿尔卡哈利德”(Al-Khalid)主战坦克(MBT)进行测试。该坦克由巴基斯坦塔克希拉(Taxila)重型工业公司生产。巴基斯坦国防部表示,如果本次测试成功,那么沙特阿拉伯将可能采购150辆MBT-2000主战坦克,总价值将达6亿美元,这将是巴基斯坦有史以来最大的一笔单项出口合同。该坦克的生产对于巴基斯坦国防工业而言,也是一个非常重要的里程碑,“阿尔卡哈利德”主战坦克的性能能够与目前世界上任何先进坦克相媲美,在重量、火力、机动性、全天候作战能力等方面都比较突出。“哈利德”主战坦克采用传统设计,即驾驶舱位于车前方,炮塔位于中间,而动力装置则放在最后边。该坦克的车体和炮塔均采用采用全焊接刚装甲,车体前弧面装甲外加挂了一层复合装甲,如果需要,也可披挂爆炸反应装甲。据估计,炮塔正面厚度为600毫米,侧面/前方突出部位的厚度为450~470毫米。    军事
......

分类标签

label.txt(分类标签文件)记录数据集中所有标签集合,每一行为一个标签名。

  • label.txt 文件格式:
<标签>
<标签>
...
  • label.txt 文件样例:
外交
军事
经济
政治
科技
安全

4、模型训练

我们推荐使用 Trainer API 对模型进行微调。只需输入模型、数据集等就可以使用 Trainer API 高效快速地进行预训练、微调和模型压缩等任务,可以一键启动多卡训练、混合精度训练、梯度累积、断点重启、日志显示等功能,Trainer API 还针对训练过程的通用训练配置做了封装,比如:优化器、学习率调度等。

4.1、预训练模型微调

使用CPU/GPU训练,默认为GPU训练。使用CPU训练只需将设备参数配置改为–device cpu,可以使用–device gpu:0指定GPU卡号:

python train.py --do_train --do_eval --do_export --model_name_or_path ernie-3.0-tiny-medium-v2-zh --output_dir checkpoint --device gpu --num_train_epochs 100 --early_stopping True --early_stopping_patience 5 --learning_rate 3e-5 --max_length 512 --per_device_eval_batch_size 32 --per_device_train_batch_size 32 --metric_for_best_model accuracy --load_best_model_at_end --logging_steps 5 --evaluation_strategy epoch --save_strategy epoch --save_total_limit 1      

如果在GPU环境中使用,可以指定gpus参数进行多卡分布式训练。使用多卡训练可以指定多个GPU卡号,例如 --gpus 0,1。如果设备只有一个GPU卡号默认为0,可使用nvidia-smi命令查看GPU使用情况:

python -m paddle.distributed.launch --gpus 0,2,5 train.py --do_train --do_eval --do_export --model_name_or_path ernie-3.0-tiny-medium-v2-zh --output_dir checkpoint --device gpu --num_train_epochs 100 --early_stopping True --early_stopping_patience 5 --learning_rate 3e-5 --max_length 512 --per_device_eval_batch_size 32 --per_device_train_batch_size 32 --metric_for_best_model accuracy --load_best_model_at_end --logging_steps 5 --evaluation_strategy epoch --save_strategy epoch --save_total_limit 1

中文训练任务(文本支持含部分英文)推荐使用"ernie-1.0-large-zh-cw"、“ernie-3.0-tiny-base-v2-zh”、“ernie-3.0-tiny-medium-v2-zh”、“ernie-3.0-tiny-micro-v2-zh”、“ernie-3.0-tiny-mini-v2-zh”、“ernie-3.0-tiny-nano-v2-zh”、“ernie-3.0-tiny-pico-v2-zh”。
英文训练任务推荐使用"ernie-3.0-tiny-mini-v2-en"、 “ernie-2.0-base-en”、“ernie-2.0-large-en”。
英文和中文以外语言的文本分类任务,推荐使用基于96种语言(涵盖法语、日语、韩语、德语、西班牙语等几乎所有常见语言)进行预训练的多语言预训练模型"ernie-m-base"、“ernie-m-large”。

主要的配置的参数为:

  • do_train: 是否进行训练。
  • do_eval: 是否进行评估。
  • debug: 与do_eval配合使用,是否开启debug模型,对每一个类别进行评估。
  • do_export: 训练结束后是否导出静态图。
  • do_compress: 训练结束后是否进行模型裁剪。
  • model_name_or_path: 内置模型名,或者模型参数配置目录路径。默认为ernie-3.0-tiny-medium-v2-zh。
  • output_dir: 模型参数、训练日志和静态图导出的保存目录。
  • device: 使用的设备,默认为gpu。
  • num_train_epochs: 训练轮次,使用早停法时可以选择100。
  • early_stopping: 是否使用早停法,也即一定轮次后评估指标不再增长则停止训练。
  • early_stopping_patience: 在设定的早停训练轮次内,模型在开发集上表现不再上升,训练终止;默认为4。
  • learning_rate: 预训练语言模型参数基础学习率大小,将与learning rate scheduler产生的值相乘作为当前学习率。
  • max_length: 最大句子长度,超过该长度的文本将被截断,不足的以Pad补全。提示文本不会被截断。
  • per_device_train_batch_size: 每次训练每张卡上的样本数量。可根据实际GPU显存适当调小/调大此配置。
  • per_device_eval_batch_size: 每次评估每张卡上的样本数量。可根据实际GPU显存适当调小/调大此配置。
  • max_length: 最大句子长度,超过该长度的文本将被截断,不足的以Pad补全。提示文本不会被截断。
  • train_path: 训练集路径,默认为"./data/train.txt"。
  • dev_path: 开发集集路径,默认为"./data/dev.txt"。
  • test_path: 测试集路径,默认为"./data/dev.txt"。
  • label_path: 标签路径,默认为"./data/label.txt"。
  • bad_case_path: 错误样本保存路径,默认为"./data/bad_case.txt"。
  • width_mult_list:裁剪宽度(multi head)保留的比例列表,表示对self_attention中的 q、k、v 以及 ffn 权重宽度的保留比例,保留比例乘以宽度(multi haed数量)应为整数;默认是None。 训练脚本支持所有TrainingArguments的参数,更多参数介绍可参考TrainingArguments 参数介绍。
    程序运行时将会自动进行训练,评估。同时训练过程中会自动保存开发集上最佳模型在指定的 output_dir 中,保存模型文件结构如下所示:
checkpoint/
├── export # 静态图模型
├── config.json # 模型配置文件
├── model_state.pdparams # 模型参数文件
├── tokenizer_config.json # 分词器配置文件
├── vocab.txt
└── special_tokens_map.json

NOTE:

  • 中文训练任务(文本支持含部分英文)推荐使用"ernie-1.0-large-zh-cw"、“ernie-3.0-tiny-base-v2-zh”、“ernie-3.0-tiny-medium-v2-zh”、“ernie-3.0-tiny-micro-v2-zh”、“ernie-3.0-tiny-mini-v2-zh”、“ernie-3.0-tiny-nano-v2-zh”、“ernie-3.0-tiny-pico-v2-zh”。
  • 英文训练任务推荐使用"ernie-3.0-tiny-mini-v2-en"、 “ernie-2.0-base-en”、“ernie-2.0-large-en”。
  • 英文和中文以外语言的文本分类任务,推荐使用基于96种语言(涵盖法语、日语、韩语、德语、西班牙语等几乎所有常见语言)进行预训练的多语言预训练模型"ernie-m-base"、“ernie-m-large”,详情请参见ERNIE-M论文。

4.2、训练评估

训练后的模型我们可以开启debug模式,对每个类别分别进行评估,并打印错误预测样本保存在bad_case.txt。默认在GPU环境下使用,在CPU环境下修改参数配置为–device “cpu”:

python train.py --do_eval --debug True --device gpu --model_name_or_path checkpoint --output_dir checkpoint --per_device_eval_batch_size 32 --max_length 512 --test_path './data/dev.txt'

4.3、结果展示

[{'predictions': [{'label': '外交', 'score': 0.931406241805932}], 'text': '9月15日,在国务院新闻办公室举办的新闻发布会上,国家统计局新闻发言人、国民经济综合
统计司司长付凌晖表示,8月份在一系列扩大内需、提振信心、防范风险的政策举措作用下,工业和服务业生产加快,国内需求继续扩大,就业物价形势向好,积极因素累积增多,国民经济延续恢复态势,发展质量稳步提高。民生银行首席经济学家温彬对《证券日报》记者表示,推动经济企稳回升的线索主要有三条:一是需求回暖。内需主要受居民消费反弹驱动,外需受全球经济的韧性支撑。二是预期改善。在一系列稳增长政策出台之后,市场信心有所提振,投资意愿开始反弹。三是价格回升。8月份CPI结束负增长,PPI降幅连续收窄,物价回暖有助于改善企业利润状况,提升投资意愿。'}]
  • 30
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值