Paper: BERT Loses Patience: Fast and Robust Inference with Early Exit, https://arxiv.org/abs/2006.04152, by Wangchunshu Zhou et al. (NeurIPS 2020, a CCF-A conference).
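I'm a beginner who knows next to nothing, and this is my first time working with transformers, so I'm writing down what I did.
1. First, download the code from GitHub and unzip it; cloning the repository works too.
2. Download the corpus (data)
I started from the Jianshu (jianshu.com) post "Bert series (1): running the demo": https://www.jianshu.com/p/3d0bb34c488a
However, no matter what I tried, the download failed with: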
Error downloading standard development IDs for MRPC. You will need to manually split your data.
In the end I just followed the post below and downloaded the data from Baidu Netdisk.
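Whichever way the data arrives, it helps to confirm the folder layout before training. Here is a minimal sketch, assuming the MRPC folder should contain train.tsv and dev.tsv (to the best of my knowledge these are the file names the GLUE MRPC processor in transformers reads). The helper name missing_glue_files and the path ./glue_data/MRPC are illustrative and not part of the PABEE code.

```python
import os

# Assumption: the MRPC task expects these two files inside its data directory.
EXPECTED_FILES = ("train.tsv", "dev.tsv")


def missing_glue_files(data_dir, expected=EXPECTED_FILES):
    """Return the expected file names that are absent from data_dir."""
    return [name for name in expected
            if not os.path.isfile(os.path.join(data_dir, name))]


if __name__ == "__main__":
    # Hypothetical location; adjust to wherever the data was unpacked.
    missing = missing_glue_files("./glue_data/MRPC")
    if missing:
        print("Missing files:", ", ".join(missing))
    else:
        print("MRPC data looks complete.")
```

An empty list means both files are present; anything else names what still has to be downloaded or split by hand.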
3. Install the dependencies
requirements.txt only lists transformers==3.5.1.
I also needed:
tensorboardX
TypeError: __init__() got an unexpected keyword argument 'serialized_options'
pip install protobuf==3.4.0
ImportError: cannot import name ‘SAVE_STATE_WARNING‘ from ‘torch.optim.lr_sc
Fix: switch from torch 1.8.0 to torch 1.7.1
scikit-learn
pip install -U scikit-learn
Python version used: 3.6.2
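The pinned versions noted above can be compared against the active environment. This is a rough sketch, assuming the distribution names are transformers, protobuf and torch. Both helper functions are illustrative and not part of the PABEE code, and a local build suffix such as +cu110 on the torch version is ignored when comparing.

```python
# Pinned versions taken from the notes above.
PINS = {"transformers": "3.5.1", "protobuf": "3.4.0", "torch": "1.7.1"}


def installed_version(name):
    """Return the installed version string of a distribution, or None."""
    try:
        from importlib import metadata  # available from Python 3.8
    except ImportError:
        import pkg_resources  # fallback for older Pythons such as 3.6
        try:
            return pkg_resources.get_distribution(name).version
        except pkg_resources.DistributionNotFound:
            return None
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pins, lookup=installed_version):
    """Return {package: (wanted, found)} for every pin that is not met."""
    mismatches = {}
    for name, wanted in pins.items():
        found = lookup(name)
        # Drop a local build suffix such as "+cu110" before comparing.
        base = found.split("+")[0] if found else None
        if base != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(PINS).items():
        print(f"{name}: want {wanted}, found {found}")
```

If nothing is printed, the environment matches the pins; a found value of None means the package is not installed at all.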
4. Run
export GLUE_DIR=/path/to/glue_data
export TASK_NAME=MRPC
python ./run_glue_with_pabee.py \
--model_type albert \
--model_name_or_path albert-base-v2 \
--task_name $TASK_NAME \
--do_train \
--do_eval \
--do_lower_case \
--data_dir "$GLUE_DIR/$TASK_NAME" \
--max_seq_length 128 \
--per_gpu_train_batch_size 32 \
--per_gpu_eval_batch_size 32 \
--learning_rate 2e-5 \
--save_steps 50 \
--logging_steps 50 \
--num_train_epochs 5 \
--output_dir /path/to/save/ \
--evaluate_during_training
The first line sets an environment variable. glue_data is the corpus obtained in step 2; I put it directly in the current directory (./):
export GLUE_DIR=./glue_data
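The second line can be copied as-is; the task name is MRPC.
The third line runs the training script.
The arguments start on the fourth line:
--model_type is the model family; albert is one option and bert is another.
--model_name_or_path is a model name or a path. albert-base-v2 is the name of one ALBERT model. Reportedly a path also works, as long as it points to a model you have already downloaded. When you give a model name, some models are downloaded automatically.
--data_dir is where the corpus is stored, i.e. the glue_data/MRPC folder downloaded in step 2.
After that comes a batch of training parameters.
--output_dir is where the training results are saved.
This is the command I ran: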
python ./run_glue_with_pabee.py --model_type albert --model_name_or_path albert-base-v2 --task_name MRPC --do_train --do_eval --do_lower_case --data_dir ./glue_data/MRPC --max_seq_length 128 --per_gpu_train_batch_size 32 --per_gpu_eval_batch_size 32 --learning_rate 2e-5 --save_steps 50 --logging_steps 50 --num_train_epochs 5 --output_dir ./outmodel --evaluate_during_training
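For repeated runs it can be convenient to assemble the same command in Python. The sketch below only builds the argument list from the flags shown above. build_command is a hypothetical helper, and the flag names and values are copied from the command line above rather than taken from the script's source.

```python
import sys


def build_command(script, options, switches):
    """Turn option/value pairs and bare switches into an argv list."""
    argv = [sys.executable, script]
    for name, value in options.items():
        argv += ["--" + name, str(value)]
    argv += ["--" + name for name in switches]
    return argv


# Values copied from the command above.
OPTIONS = {
    "model_type": "albert",
    "model_name_or_path": "albert-base-v2",
    "task_name": "MRPC",
    "data_dir": "./glue_data/MRPC",
    "max_seq_length": 128,
    "per_gpu_train_batch_size": 32,
    "per_gpu_eval_batch_size": 32,
    "learning_rate": "2e-5",
    "save_steps": 50,
    "logging_steps": 50,
    "num_train_epochs": 5,
    "output_dir": "./outmodel",
}
SWITCHES = ["do_train", "do_eval", "do_lower_case", "evaluate_during_training"]

if __name__ == "__main__":
    cmd = build_command("./run_glue_with_pabee.py", OPTIONS, SWITCHES)
    print(" ".join(cmd))
    # To launch training: import subprocess; subprocess.run(cmd, check=True)
```

Keeping the options in a dictionary makes it easy to change one value, for example output_dir, without retyping the whole line.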
The run printed a pile of warnings, but at least it produced a result:
/nfs/home/***/.conda/envs/PABEE/lib/python3.6/site-packages/transformers/data/metrics/__init__.py:66: FutureWarning: This metric will be removed from the library soon, metrics should be handled with the 🤗 Datasets library. You can have a look at this example script for pointers: https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_glue.py
warnings.warn(DEPRECATION_WARNING, FutureWarning)
/nfs/home/***/.conda/envs/PABEE/lib/python3.6/site-packages/transformers/data/metrics/__init__.py:42: FutureWarning: This metric will be removed from the library soon, metrics should be handled with the 🤗 Datasets library. You can have a look at this example script for pointers: https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_glue.py
warnings.warn(DEPRECATION_WARNING, FutureWarning)
/nfs/home/***/.conda/envs/PABEE/lib/python3.6/site-packages/transformers/data/metrics/__init__.py:36: FutureWarning: This metric will be removed from the library soon, metrics should be handled with the 🤗 Datasets library. You can have a look at this example script for pointers: https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_glue.py
warnings.warn(DEPRECATION_WARNING, FutureWarning)
11/26/2021 15:12:13 - INFO - __main__ - ***** Eval results *****
11/26/2021 15:12:13 - INFO - __main__ - acc = 0.8480392156862745
acc = 0.8480392156862745
11/26/2021 15:12:13 - INFO - __main__ - acc_and_f1 = 0.8696336429308567
acc_and_f1 = 0.8696336429308567
11/26/2021 15:12:13 - INFO - __main__ - f1 = 0.8912280701754387
f1 = 0.8912280701754387
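The three numbers are related. As far as I can tell, the GLUE metric helper in transformers reports acc_and_f1 for MRPC as the plain average of accuracy and F1. Here is a quick check with the values from the log above; the function name mean_acc_f1 is hypothetical, and the averaging rule is my assumption about how the metric is defined.

```python
import math


def mean_acc_f1(acc, f1):
    """Average accuracy and F1 (the assumed definition of acc_and_f1)."""
    return (acc + f1) / 2


# Values copied from the evaluation log above.
acc = 0.8480392156862745
f1 = 0.8912280701754387
reported = 0.8696336429308567

# The recomputed mean should agree with the logged value up to float rounding.
print(math.isclose(mean_acc_f1(acc, f1), reported, rel_tol=1e-12))
```

So acc_and_f1 carries no extra information here; it is a convenience summary of the other two metrics.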