BERT学习笔记:run_classifier.py

BERT 源码初探之 run_classifier.py

本文源码来源于 Github上的BERT 项目中的 run_classifier.py 文件。阅读本文需要对Attention Is All You Need以及BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding两篇论文有所了解,以及部分关于深度学习自然语言处理Tensorflow的储备知识。

0 前言

1 简介

2 classifier使用方法

2.1 参数设置

2.1.1 必须参数

源码中提到的必须要设置的参数如下所示:

flags.DEFINE_string(
    "data_dir", None,
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.")
  • 名称:data_dir
  • 默认值: None
  • 注释:输入数据的目录,需要包含任务所需的 .csv 格式(或者其他类型)的文件。
flags.DEFINE_string(
    "bert_config_file", None,
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.")
  • 名称:bert_config_file
  • 默认值:None
  • 注释:和预训练过的bert模型对应的json格式的配置文件,这指定了模型的架构。
flags.DEFINE_string("task_name", None, "The name of the task to train.")
  • 名称:task_name
  • 默认值:None
  • 注释:要训练的任务名称。
flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")
  • 名称:vocab_file
  • 默认值:None
  • 注释:BERT模型训练所用的词典文件
flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")
  • 名称:output_dir
  • 默认值:None
  • 注释:模型checkpoint会被保存到的路径位置。
2.1.2 其他参数
flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained BERT model).")
  • 名称:init_checkpoint
  • 默认值:None
  • 注释:初始的checkpoint(通常来源于一个预训练过的BERT模型)。
flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")
  • 名称:do_lower_case
  • 默认值:True
  • 注释:是否把输入的文本全部小写。对于所有的 uncased models(不保留大小写以及重音标记)应该设为True,而对于 cased models(保留大小写以及重音标记,如果你认为这对于你的训练任务是有益的)则应该设为False
flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")
  • 名称:max_seq_length
  • 默认值:128
  • 注释:WordPiece tokenization处理之后的最大输入序列长度。比这个长度长的输入序列将被阶段,而比这个短的输入序列将被补齐。
flags.DEFINE_bool("do_train", False, "Whether to run training.")
  • 名称:do_train
  • 默认值:False
  • 注释:是否运行训练
flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
  • 名称:do_eval
  • 默认值:False
  • 注释:是否在验证集上进行计算(eval)。
flags.DEFINE_bool(
    "do_predict", False,
    "Whether to run the model in inference mode on the test set.")
  • 名称:do_predict
  • 默认值:False
  • 注释:是否在测试集上用推断模式来运行模型。
flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
  • 名称:train_batch_size
  • 默认值:32
  • 注释:训练过程中的批尺寸。
flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
  • 名称:eval_batch_size
  • 默认值:8
  • 注释:验证过程中的批尺寸。
flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
  • 名称:predict_batch_size
  • 默认值:8
  • 注释:预测过程中的批尺寸。
flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
  • 名称:learning_rate
  • 默认值:5e-5
  • 注释:Adam算法的初始学习率。
flags.DEFINE_float("num_train_epochs", 3.0,
                   "Total number of training epochs to perform.")
  • 名称:num_train_epochs
  • 默认值:3.0
  • 注释:完成训练所需要遍历训练数据集的次数
flags.DEFINE_float(
    "warmup_proportion", 0.1,
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.")
  • 名称:warmup_proportion
  • 默认值:0.1
  • 注释:用于warm up的训练步数比例。其中0.1表示训练步数的10%。
flags.DEFINE_integer("save_checkpoints_steps", 1000,
                     "How often to save the model checkpoint.")
  • 名称:save_checkpoints_steps
  • 默认值:1000
  • 注释:多久保存一次模型
flags.DEFINE_integer("iterations_per_loop", 1000,
                     "How many steps to make in each estimator call.")
  • 名称:iterations_per_loop
  • 默认值:1000
  • 注释:每隔多少步来调用一次估计函数。
2.1.3 TPU相关参数
flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
  • 名称:use_tpu
  • 默认值:False
  • 注释:是否使用TPU
tf.flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")
  • 名称:tpu_name
  • 默认值:None
  • 注释:用于训练的云TPU,应该是创建云TPU时使用的名称,或者是类似 grpc://ip.address.of.tpu:8470 这种类型的URL。
tf.flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")
  • 名称:tpu_zone
  • 默认值:None
  • 注释:【可选】 云TPU所在的GCE区域,如果没有指定该参数,程序将会自动在元数据中检测。
tf.flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")
  • 名称:gcp_project
  • 默认值:None
  • 注释:【可选】 云TPU项目的项目名称,如果不指定程序将自动从元数据中检测。
tf.flags.DEFINE_string("master", None, "[Optional
  • 21
    点赞
  • 58
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值