1.SimpleTransformers
基于transformers库开发的,快速实现模型调用、训练和预测的工具
1.1 支持的任务
- 文本分类
- 多分类
- 句子对应多标签
- 实体识别Token Classification (NER)
- 阅读理解Question Answering
1.2 安装
conda or pip(官方推荐conda)
依赖:pandas、torch
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install simpletransformers
1.3 参数
创建模型的过程中需要传入参数类,参数分为comment和spicial,列举重要的参数及含义
参数 | 含义 |
---|---|
output_dir | 输出结果位置 |
tensorboard_dir | tensorboard输出地址 |
best_model_dir | 当evaluate_during_training为true,保存最好模型的位置 |
evaluate_during_training | train与evaluation同时进行 |
cache_dir | 缓存文件位置 |
save_model_every_epoch | 每一次迭代储存一次model checkpoint |
2.多分类任务
项目快速实现对文本的多分类与API接口
数据准备
- DataFrame,text and lables
- lables必须从0开始,映射为整数数字
- KFold划分样本集
模型训练
模型调用
3.遇到的问题
ImportError: dlopen: cannot load any more object with static TLS
解决:调整torch、sklearn、simpletransformers的导入顺序