彻底弄懂bert文本分类源码

Bert & Transformer文本分类源码详解

参考论文

https://arxiv.org/abs/1706.03762

https://arxiv.org/abs/1810.04805

在本文中,我将以run_classifier.py以及MRPC数据集为例介绍关于bert以及transformer的源码,官方代码基于tensorflow-gpu 1.x,若为tensorflow 2.x版本,会有各种错误,建议切换版本至1.14。

当然,注释好的源代码在这里

章节

Demo传参

首先大家拿到这个模型,管他什么原理,肯定想跑起来看看结果,至于预训练模型以及数据集下载。任何时候应该先看官方教程,官方代表着权威,更容易实现,如果遇到问题可以去issues和stackoverflow看看,再辅以中文教程,一般上手就不难了,这里就不再赘述了。

先从Flags参数讲起,到如何跑通demo。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-O3f8ctQ6-1618457232403)(https://github.com/sherlcok314159/ML/blob/main/nlp/Images/flags.png)]

拿到源码不要慌张,英文注释往往起着最关键的作用,另外阅读源码详细技巧可以看源码技巧

"Required Parameters"意思是必要参数,你等会执行时必须向程序里面传的参数。

export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export GLUE_DIR=/path/to/glue

python run_classifier.py \
  --task_name=MRPC \
  --do_train=true \
  --do_eval=true \
  --data_dir=$GLUE_DIR/MRPC \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=/tmp/mrpc_output/

这是官方给的示例,这个将两个文件夹加入了系统路径,本人Ubuntu18.04加了好像也找不到,所以建议将那些文件路径改为绝对路径

task_name --> 这次任务的名称
do_train --> 是否做fine-tune
do_eval --> 是否交叉验证
do_predict --> 是否做预测
data_dir --> 数据集的位置
vocab_dir --> 词表的位置(一般bert模型下好就能找到) 
bert_config --> bert模型参数设置
init_checkpoint --> 预训练好的模型
max_seq_length 
  • 5
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值