官方文档:https://github.com/brightmart/roberta_zh
步骤
- 复制项目: git clone https://github.com/brightmart/roberta_zh
- 下载中文预训练模型:https://drive.google.com/open?id=1ykENKV7dIFAqRRQbZIh0mSb7Vjc2MeFA,解压到项目外层
- 和 run_classifier.py 同一层建立 model(存放微调后得到的新模型)、data(存放训练、测试文件)文件夹
- 修改 run_classifier.py 中的文件处理 Class,修改Flag(如data_dir等)
- nohup python -u run_classifier.py > run_classifier.log 2>&1
遇到的坑
L12 pytorch版本的,在后面运行会报错,因为是 model.bin 而不是 ckpt
如果用工程的 run_classifier.py 运行,应该要改成这样
export BERT_BASE_DIR=./model/roberta_zh_l12
export MY_DATA_DIR=./data/lcqmc
python run_classifier.py \
--task_name=lcqmc_pair \
--do_train=true \
--do_eval=true \
--data_dir=$MY_DATA_DIR \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--max_seq_length=128 \
--train_batch_size=64 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--output_dir=./checkpoint_lcqmc
numpy 版本问题参考:https://blog.csdn.net/qq_15694045/article/details/100577784