文档地址
https://github.com/google-research/bert
准备工作
## Required parameters
## Dataset directory: contains the train, dev and prediction sets (downloadable attachment)
flags.DEFINE_string(
"data_dir", "数据文件/tmp",
"The input data dir. Should contain the .tsv files (or other data files) "
"for the task.")
## BERT configuration file (downloadable attachment)
flags.DEFINE_string(
"bert_config_file", "数据文件/chinese_L-12_H-768_A-12/bert_config.json",
"The config json file corresponding to the pre-trained BERT model. "
"This specifies the model architecture.")
## Task to launch; the key piece is a custom DataProcessor subclass that prepares the dataset
flags.DEFINE_string("task_name", "csv", "The name of the task to train.")
## Vocabulary file
flags.DEFINE_string("vocab_file", "数据文件/chinese_L-12_H-768_A-12/vocab.txt",
"The vocabulary file that the BERT model was trained on.")
## Output directory
flags.DEFINE_string(
"output_dir", "output",
"The output directory where the model checkpoints will be written.")
## Other parameters
## Base pre-trained model checkpoint
flags.DEFINE_string(
"init_checkpoint", "数据文件/chinese_L-12_H-768_A-12/bert_model.ckpt",
"Initial checkpoint (usually from a pre-trained BERT model).")
## Whether to lower-case the input text
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
## Maximum sequence length
flags.DEFINE_integer(
"max_seq_length", 128,
"The maximum total input sequence length after WordPiece tokenization. "
"Sequences longer than this will be truncated, and sequences shorter "
"than this will be padded.")
## Whether to run training
flags.DEFINE_bool("do_train", True, "Whether to run training.")
## Whether to run evaluation
flags.DEFINE_bool("do_eval", True, "Whether to run eval on the dev set.")
## Whether to run prediction
flags.DEFINE_bool(
"do_predict", True,
"Whether to run the model in inference mode on the test set.")
关键实现代码
bert自带文本分类run_classifier.py,新增一个实现DataProcessor的数据集处理类即可处理自己想要处理的数据,然后加入到处理器即可。
处理器
# Registry mapping the --task_name flag value to its DataProcessor class.
processors = dict(
    cola=ColaProcessor,
    mnli=MnliProcessor,
    mrpc=MrpcProcessor,
    xnli=XnliProcessor,
    csv=CsvProcessor,
)
基本的数据集处理器DataProcessor
class DataProcessor(object):
    """Abstract base class for data converters for sequence classification.

    Subclasses provide `InputExample`s for each split plus the label set;
    `_read_tsv` is a shared helper for loading tab-separated files.
    """

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self, data_dir):
        """Gets a collection of `InputExample`s for prediction."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab-separated value file into a list of row lists."""
        with tf.gfile.Open(input_file, "r") as reader_file:
            return list(csv.reader(reader_file, delimiter="\t", quotechar=quotechar))
自定义数据集处理器CsvProcessor
class CsvProcessor(DataProcessor):
    """Processor for a custom CSV data set (binary "O" / "CS" classification).

    Each data file is read by `_read_tsv`; the entire record is expected in
    the first tab-separated cell as comma-joined "label,text" — TODO confirm
    against the actual train.csv/dev.csv/test.csv layout.
    """

    ## Get the training set
    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.csv")), "train")

    ## Get the dev (validation) set
    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.csv")), "dev")

    ## Get the prediction set
    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.csv")), "test")

    ## The label set: the task has the two classes "O" and "CS"
    def get_labels(self):
        """See base class."""
        return ["O", "CS"]

    ## The actual row-to-example conversion logic
    def _create_examples(self, lines, set_type):
        """Creates `InputExample`s for the train/dev/test sets.

        Args:
          lines: rows from `_read_tsv` (each row is a list of cells).
          set_type: one of "train", "dev", "test"; used in the guid and to
            select per-split skipping rules.
        """
        examples = []
        for (i, line) in enumerate(lines):
            # The first row holds the column names; skip it.
            if i == 0:
                continue
            # Bug fix: csv.reader yields an empty list for a blank line, so
            # indexing line[0] would raise IndexError. Skip such rows.
            if not line:
                continue
            ## The whole record lives in the first cell, comma-separated.
            data = line[0].split(",")
            ## Unique id for this example.
            guid = "%s-%s" % (set_type, i)
            ## Skip dev rows with an empty label.
            if set_type == "dev" and data[0] == "":
                continue
            if set_type == "test":
                ## Skip test rows with missing or empty text. (Bug fix: the
                ## original `data[1] == ""` check itself raised IndexError
                ## when a row had only one field.)
                if len(data) < 2 or data[1] == "":
                    continue
                text_a = tokenization.convert_to_unicode(data[1])
                ## Prediction rows get the placeholder label "CS".
                label = "CS"
            else:
                ## First field is the label, second is the text.
                text_a = tokenization.convert_to_unicode(data[1])
                label = tokenization.convert_to_unicode(data[0])
            ## Build the example and collect it.
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples
计算PRF
def metric_fn(per_example_loss, label_ids, logits, is_real_example):
    """Builds the eval metric dict: accuracy, AUC, P, R, F1 and mean loss.

    Each value is a TF1 `(value_tensor, update_op)` metric pair, weighted by
    `is_real_example` so padding examples are ignored.
    """
    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
    accuracy = tf.metrics.accuracy(
        labels=label_ids, predictions=predictions, weights=is_real_example)
    # NOTE(review): AUC over hard 0/1 predictions (not probabilities) is
    # degenerate — consider feeding softmax scores instead; confirm intent.
    auc = tf.metrics.auc(labels=label_ids, predictions=predictions, weights=is_real_example)
    # P of PRF
    precision = tf.metrics.precision(labels=label_ids, predictions=predictions, weights=is_real_example)
    # R of PRF
    recall = tf.metrics.recall(labels=label_ids, predictions=predictions, weights=is_real_example)
    # Bug fix: the original comment promised F = (2 * P * R) / (P + R) but
    # never computed it. Derive F1 from the running P/R value tensors and
    # group their update ops; the epsilon avoids division by zero when
    # P + R == 0.
    f1_value = (2.0 * precision[0] * recall[0]) / (precision[0] + recall[0] + 1e-12)
    f1 = (f1_value, tf.group(precision[1], recall[1]))
    loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
    return {
        "eval_accuracy": accuracy,
        "eval_auc": auc,
        "eval_precision": precision,
        "eval_recall": recall,
        "eval_f1": f1,
        "eval_loss": loss,
    }
源码地址
https://download.csdn.net/download/rainbowBear/16747704