文本分类:3、新秀Bert文本分类

自从Bert刷新了几大经典NLP任务后,企业也逐渐采用Bert----效果拔群、性能感人。。本文主要记录下使用过程。
下载模型BERT-Base,Chiness,clone 代码
数据格式:label   \tab   query

文本分类时,只需要修改 run_classify.py 文件:
run_classify.py阅读:main()入口
1、初始化加载checkpoint
2、bert_config初始化,save_checkpoints_steps每隔steps保存一次
3、model_fn:model_fn_builder返回model_fn函数(输入feature、label,调create_model得loss、probabilities等,计算acc等并返回)
estimator:传日=入model_fn
4、if FLAGS.do_train :
    加载train数据,file_based_convert_examples_to_features()把每行转为feature(input_ids、input_mask、segment_ids、label_id),并写到tf_record文件中。
    train_input_fn 传入到estimator中
5、eval、predict过程类似。。。建议先看看estimator文档。。。

简单使用示例可参考这个
1、定义数据处理类
基类Dataprocessor,定义自己的类 DomainDataProcessor

class DomainDataProcessor(DataProcessor):
  """Processor for the XNLI data set."""
  def __init__(self):
    self.language = "zh"
    self._label = [];
    self._label_map = {};
  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")),
        "dev")
  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  def get_labels(self, data_dir):
    """See base class."""
    if not self._label:
         for line in open(os.path.join(data_dir, "label.txt")):
            col = line.strip().split('\t')
            target_label = col[0]
            if len(col) == 2:  # 这个是将之前的细标签转为粗标签
                target_label = col[1]
            if target_label not in self._label:
                self._label.append(target_label)
            self._label_map[col[0]] = target_label;

    return self._label
  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      text_a = tokenization.convert_to_unicode(line[1])
      label = tokenization.convert_to_unicode(self._label_map[line[0]])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
    return examples

2、main()函数内加入:

processors = {
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "domain": DomainDataProcessor,  # 自己的数据读取类
  }

3、执行脚本 run.sh

BERT_BASE_DIR=/home/hadoop-mtai/cephfs/data/crz/model/chinese_L-12_H-768_A-12
DATA_DIR=/opt/crz/bert_exp/Data
OUTPUT_DIR=../output_nodict
CUDA_VISIBLE_DEVICES=0 python run_classifier.py \
 --task_name=domain \
 --do_train=true \
 --do_predict=true \
 --do_eval=true \
 --data_dir=$DATA_DIR/ \
 --vocab_file=$BERT_BASE_DIR/vocab.txt \
 --bert_config_file=$BERT_BASE_DIR/bert_config.json \
 --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
 --max_seq_length=30 \
 --train_batch_size=32 \
 --learning_rate=2e-5 \
 --num_train_epochs=5.0 \
 --output_dir=$OUTPUT_DIR \

几个修改:TODO 代码上传。。。
1、模型输出
2、Loss函数
3、在每个checkpoint上验证、预测
4、保存模型 savemodel
5、冻结bert参数
6、加入其他特征

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值