BERT --> Text Classification

Documentation

https://github.com/google-research/bert

Preparation

## Required parameters

## Dataset: contains the training, dev, and prediction sets (downloadable as an attachment)
flags.DEFINE_string(
    "data_dir", "数据文件/tmp",
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.")

## The BERT config file (downloadable as an attachment)
flags.DEFINE_string(
    "bert_config_file", "数据文件/chinese_L-12_H-768_A-12/bert_config.json",
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.")

## The task to launch. This is the most critical piece: DataProcessor is subclassed to do the dataset-processing work
flags.DEFINE_string("task_name", "csv", "The name of the task to train.")

## The vocabulary file
flags.DEFINE_string("vocab_file", "数据文件/chinese_L-12_H-768_A-12/vocab.txt",
                    "The vocabulary file that the BERT model was trained on.")

## The output directory
flags.DEFINE_string(
    "output_dir", "output",
    "The output directory where the model checkpoints will be written.")

## Other parameters

## The base pre-trained model
flags.DEFINE_string(
    "init_checkpoint", "数据文件/chinese_L-12_H-768_A-12/bert_model.ckpt",
    "Initial checkpoint (usually from a pre-trained BERT model).")

## Whether to lowercase the input text
flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

## The maximum sequence length
flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

## Whether to run training
flags.DEFINE_bool("do_train", True, "Whether to run training.")

## Whether to run evaluation
flags.DEFINE_bool("do_eval", True, "Whether to run eval on the dev set.")

## Whether to run prediction
flags.DEFINE_bool(
    "do_predict", True,
    "Whether to run the model in inference mode on the test set.")

Key implementation code

BERT already ships with run_classifier.py for text classification. To classify your own data, add a dataset-processing class that implements DataProcessor and register it in the processors map; the lookup that ties the task_name flag to the new class is shown right after the map.

Processors

processors = {
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "csv": CsvProcessor
  }
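For reference, this is how main() in run_classifier.py resolves the task_name flag into one of these processors, so --task_name=csv selects the CsvProcessor defined below (abridged from the original main(); FLAGS is the module-level flags object):

task_name = FLAGS.task_name.lower()
if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()
label_list = processor.get_labels()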

The base dataset processor: DataProcessor

class DataProcessor(object):
  """Base class for data converters for sequence classification data sets."""

  def get_train_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the train set."""
    raise NotImplementedError()

  def get_dev_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the dev set."""
    raise NotImplementedError()

  def get_test_examples(self, data_dir):
    """Gets a collection of `InputExample`s for prediction."""
    raise NotImplementedError()

  def get_labels(self):
    """Gets the list of labels for this data set."""
    raise NotImplementedError()

  @classmethod
  def _read_tsv(cls, input_file, quotechar=None):
    """Reads a tab separated value file."""
    with tf.gfile.Open(input_file, "r") as f:
      reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
      lines = []
      for line in reader:
        lines.append(line)
      return lines
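One detail worth noting before the subclass below: _read_tsv splits on tabs, so a row of a comma-separated .csv that contains no tab characters comes back as a single-element list, which is why CsvProcessor can call line[0].split(","). A minimal illustration, with hypothetical input text:

import csv, io

reader = csv.reader(io.StringIO("CS,这是一条正例文本\n"), delimiter="\t")
print(next(reader))  # ['CS,这是一条正例文本'], i.e. the whole row sits in cell 0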

The custom dataset processor: CsvProcessor

class CsvProcessor(DataProcessor):
  """Processor for the CoLA data set (GLUE version)."""

  ## Get the training set
  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.csv")), "train")
        
  ## Get the validation (dev) set
  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.csv")), "dev")
        
  ## Get the prediction (test) set
  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test.csv")), "test")

  ## The list of class labels
  ## Taking train.csv as the example, training uses two classes: "O" and "CS"
  def get_labels(self):
    """See base class."""
    return ["O", "CS"]

  ## The actual data-processing logic
  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    	
	# i是每一行的索引
	# line是每一行内容
    for (i, line) in enumerate(lines):

	  ## 获取到第一格的内容并且以逗号分隔字符串
      data = line[0].split(",")

      # 第一行是字段名称,忽略处理
      if i == 0:
        continue
       
      ## 每条数据的id
      guid = "%s-%s" % (set_type, i)

	  ## 忽略验证集的空行
      if set_type == "dev" and data[0] == "":
          continue

	  ## 忽略预测集的空行
      if set_type == "test" and data[1] == "":
          continue

	  ## 如果是预测集,label指定使用"CS"进行预测
      if set_type == "test":
        text_a = tokenization.convert_to_unicode(data[1])
        label = "CS"

	  ## 第一个是文本分类,第二个是文本内容
      else:
        text_a = tokenization.convert_to_unicode(data[1])
        label = tokenization.convert_to_unicode(data[0])

	  ## 构建对象并且加进列表
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))

    return examples
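Taken together, _create_examples implies a file format with a header row followed by label,text rows, where only the first two comma-separated fields are read (text containing commas would therefore need a different layout). A hypothetical smoke test, assuming the BERT repo modules (tokenization, TF 1.x) are importable and CsvProcessor above is defined:

import os, tempfile

# Write a tiny train.csv in the expected "label,text" format
data_dir = tempfile.mkdtemp()
with open(os.path.join(data_dir, "train.csv"), "w", encoding="utf-8") as f:
    f.write("label,text\n")
    f.write("CS,这条文本属于CS类\n")
    f.write("O,这条文本属于O类\n")

# Parse it and inspect the resulting InputExamples
processor = CsvProcessor()
for ex in processor.get_train_examples(data_dir):
    print(ex.guid, ex.label, ex.text_a)  # e.g. train-1 CS 这条文本属于CS类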

Computing P/R/F (precision, recall, F-score)

def metric_fn(per_example_loss, label_ids, logits, is_real_example):
    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
    accuracy = tf.metrics.accuracy(
        labels=label_ids, predictions=predictions, weights=is_real_example)

    # AUC over the predicted class ids
    auc = tf.metrics.auc(labels=label_ids, predictions=predictions, weights=is_real_example)

    # The P in P/R/F: precision
    precision = tf.metrics.precision(labels=label_ids, predictions=predictions, weights=is_real_example)

    # The R in P/R/F: recall
    recall = tf.metrics.recall(labels=label_ids, predictions=predictions, weights=is_real_example)

    # The F-score follows as (2 * P * R) / (P + R); see the sketch below

    loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
    return {
        "eval_accuracy": accuracy,
        "eval_auc": auc,
        "eval_precision": precision,
        "eval_recall": recall,
        "eval_loss": loss,
    }
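The dict above reports precision and recall but never computes the F-score itself. A minimal sketch of adding it inside metric_fn, assuming TF 1.x, where every tf.metrics.* call returns a (value, update_op) pair; the "eval_f1" key is an addition, not part of the original script:

# Unpack the (value, update_op) pairs instead of storing them whole
precision_val, precision_op = tf.metrics.precision(
    labels=label_ids, predictions=predictions, weights=is_real_example)
recall_val, recall_op = tf.metrics.recall(
    labels=label_ids, predictions=predictions, weights=is_real_example)
# F1 = (2 * P * R) / (P + R); the small epsilon guards against division by zero
f1_val = 2 * precision_val * recall_val / (precision_val + recall_val + 1e-8)
f1 = (f1_val, tf.group(precision_op, recall_op))
# then add "eval_f1": f1 to the returned metrics dict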

Source code

https://download.csdn.net/download/rainbowBear/16747704