自从Bert刷新了几大经典NLP任务后,企业也逐渐采用Bert----效果拔群、性能感人。。本文主要记录下使用过程。
下载模型BERT-Base,Chiness,clone 代码
数据格式:label \tab query
文本分类时,只需要修改 run_classify.py 文件:
run_classify.py阅读:main()入口
1、初始化加载checkpoint
2、bert_config初始化,save_checkpoints_steps每隔steps保存一次
3、model_fn:model_fn_builder返回model_fn函数(输入feature、label,调create_model得loss、probabilities等,计算acc等并返回)
estimator:传日=入model_fn
4、if FLAGS.do_train :
加载train数据,file_based_convert_examples_to_features()把每行转为feature(input_ids、input_mask、segment_ids、label_id),并写到tf_record文件中。
train_input_fn 传入到estimator中
5、eval、predict过程类似。。。建议先看看estimator文档。。。
简单使用示例:可参考这个
1、定义数据处理类
基类Dataprocessor,定义自己的类 DomainDataProcessor
class DomainDataProcessor(DataProcessor):
"""Processor for the XNLI data set."""
def __init__(self):
self.language = "zh"
self._label = [];
self._label_map = {};
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")),
"dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self, data_dir):
"""See base class."""
if not self._label:
for line in open(os.path.join(data_dir, "label.txt")):
col = line.strip().split('\t')
target_label = col[0]
if len(col) == 2: # 这个是将之前的细标签转为粗标签
target_label = col[1]
if target_label not in self._label:
self._label.append(target_label)
self._label_map[col[0]] = target_label;
return self._label
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(self._label_map[line[0]])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
2、main()函数内加入:
processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mrpc": MrpcProcessor,
"xnli": XnliProcessor,
"domain": DomainDataProcessor, # 自己的数据读取类
}
3、执行脚本 run.sh
BERT_BASE_DIR=/home/hadoop-mtai/cephfs/data/crz/model/chinese_L-12_H-768_A-12
DATA_DIR=/opt/crz/bert_exp/Data
OUTPUT_DIR=../output_nodict
CUDA_VISIBLE_DEVICES=0 python run_classifier.py \
--task_name=domain \
--do_train=true \
--do_predict=true \
--do_eval=true \
--data_dir=$DATA_DIR/ \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--max_seq_length=30 \
--train_batch_size=32 \
--learning_rate=2e-5 \
--num_train_epochs=5.0 \
--output_dir=$OUTPUT_DIR \
几个修改:TODO 代码上传。。。
1、模型输出
2、Loss函数
3、在每个checkpoint上验证、预测
4、保存模型 savemodel
5、冻结bert参数
6、加入其他特征