NLP文本分类pytorch框架-支持Bert等预训练语言模型

基本信息

基于Pytorch的文本分类框架。

同时支持中英文的数据集的文本分类任务。

项目地址:https://github.com/wzzzd/text_classifier_pytorch

支持的模型

  • 非预训练类模型:
    • FastText
    • TextCNN
    • TextRNN
    • TextRCNN
    • Transformer
  • 预训练类模型
    • Bert
    • Albert
    • Roberta
    • Distilbert
    • Electra
    • XLNet

支持的训练方式

  • 支持中英文语料训练
    • 支持中英文的文本分类任务。
  • 支持多种模型使用
    • 配置文件Config.py中的变量model_name表示模型名称,可以更改成你想要加载的模型名称。
    • 若是预训练类的模型,如Bert等,需要同步修改变量initial_pretrain_modelinitial_pretrain_tokenizer,修改为你想要加载的预训练参数。
  • 混合精度训练
    • 用于提升训练过程效率,缩短训练时间。
    • 配置文件Config.py中的变量fp16值改为True
  • GPU多卡训练
    • 用于分布式训练,支持单机单卡、多卡训练。
    • 配置文件Config.py中的变量cuda_visible_devices用于设置可见的GPU卡号,多卡情况下用,间隔开。
  • 对抗训练
    • 在模型embedding层增加扰动,使模型学习对抗扰动,提升表现,需要额外增加训练时间。
    • 配置文件Config.py中的变量adv_option用于设置可见的对抗模式,目前支持FGM/PGD。
  • 对比学习
    • 用于增强模型语义特征提取能力,借鉴Rdrop和SimCSE的思想,目前支持KL loss和InfoNCE两种损失。
    • 配置文件Config.py中的变量cl_option设置为True则表示开启对比学习模式,cl_method用于设置计算对比损失的方法。

数据集

  • THUCNews

    • 来自:https://github.com/649453932/Chinese-Text-Classification-Pytorch
    • 关于THUCNews的的数据。
    • 数据分为10个类标签类别,分别为:财经、房产、股票、教育、科技、社会、时政、体育、游戏、娱乐
  • 加入自己的数据集

    • 可使用本项目的处理方式,将数据集切分为3部分:train/valid/test,其中token和label之间用制表符\t分割。
    • 在 ./dataset 目录下新建一个文件夹,并把3个数据文件放置新建文件夹下。
  • 数据集示例

    • 以数据集THUCNews为栗子,文本和标签使用空格隔开,采用以下形式存储:
        午评沪指涨0.78%逼近2800 汽车家电农业领涨	2
        卡佩罗:告诉你德国脚生猛的原因 不希望英德战踢点球	7
    

实验

说明:预训练模型基于transformers框架,如若想要替换成其他预训练参数,可以查看transformers官方网站

模型名称MicroF1LearningRate预训练参数
FastText0.89261e-3-
TextCNN0.90091e-3-
TextRNN0.90801e-3-
TextRCNN0.91421e-3-
Tramsformer(2 layer)0.88491e-3-
Albert0.91242e-5voidful/albert_chinese_tiny
Distilbert0.92092e-5Geotrend/distilbert-base-zh-cased
Bert0.94012e-5bert-base-chinese
Roberta0.94482e-5hfl/chinese-roberta-wwm-ext
Electra0.93772e-5hfl/chinese-electra-base-discriminator
XLNet0.90512e-5无参数初始化

环境配置

Python使用的是3.6.X版本,其他依赖模块如下:

    numpy==1.19.2
    pandas==1.1.5
    scikit_learn==1.0.2
    torch==1.8.0
    tqdm==4.62.3
    transformers==4.15.0
    apex==0.1

除了apex需要额外安装(参考官网:https://github.com/NVIDIA/apex
),其他模块可通过以下命令安装依赖包

    pip install -r requirement.txt

如何使用项目代码

1. 训练

准备好训练数据后,终端可运行命令

    python3 main.py

2 测试评估

加载已训练好的模型,并使用valid set作模型测试,输出文件到 ./dataset/${your_dataset}/output/output.txt 目录下。

需要修改Config文件中的变量值mode = 'test',并保存。

终端可运行命令

    python3 main.py

参考

[Github:transformers] https://github.com/huggingface/transformers

[Paper:Bert] https://arxiv.org/abs/1810.04805

[Paper:RDrop] https://arxiv.org/abs/2106.14448

[Paper:SimCSE] https://arxiv.org/abs/2104.08821

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值