bert 文本分类_用Bert进行文本分类

该博客介绍了如何使用BERT进行文本分类,包括创建DataProcessor类以处理数据,定义get_train_examples等方法读取数据,注册Processor,以及设置训练、验证和测试流程的参数。最后,说明了如何从BERT模型获取句向量和词向量,以便用于自己的分类模型。
摘要由CSDN通过智能技术生成
BERT中文文本相似度计算与文本分类 - Welcome to AI World​terrifyzhao.github.io
f082561b07aec23556fd6a57fd313133.png

1. DataProcessor

顾名思义,Processor就是用来获取对应的训练集、验证集、测试集的数据与label的数据,并把这些数据 喂给BERT的,而我们要做的就是自定义新的Processor并重写这4个方法,也就是说我们只需要提供我们自己场景对应的数据。需要注意的有:
  • 读取的数据需要封装成一个InputExample的对象并添加到list中,注意这里有一个guid的参数,这个参数是必填的,是用来区分每一条数据的。
  • get_labels方法返回的是一个数组,因为相似度问题可以理解为分类问题,所以返回的标签只有0和1,注意,这里我返回的是参数是字符串,所以在重写获取数据的方法时InputExample中的label也要传字符串的数据
  • 接下来还需要给Processor加一个名字,让我们的在运行时告诉代码我们要执行哪一个Processor

1.1 创建自己的Processor类

实现流程

  • 创建自己的数据处理类
  • 实现数据读取的方法: get_train_examples, get_dev_examples, get_test_examples, get_labels
  • 其他辅助读取数据的方法(如有必要)
  • main 中注册
# 基类
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self, data_dir):
        """Gets a collection of `InputExample`s for prediction."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with tf.gfile.Open(input_file, "r") as f:
          reader = csv.reader(f, delimiter="t", quotechar=quotechar)
          lines = []
          for line in reader:
            lines.append(line)
          return lines


# TODO: 创建任务的数据处理类,从DataProcessor继承
class YourProcessor(DataProcessor):
    """Processor for the simple data set."""

    def __init__(self):
        pass

    # bert框架调用,返回模型可用的数据样本
    # 1. read file
    # 2. create examples using InputExample Class
    # 3. labels

    # 0. call to gen and return train/dev/test set 
    def get_train_examples(self, data_dir):
        return self._create_examples(
            self._read_csv(os.path.join(data_dir, 'trainData.csv')), 'train'
        )

    def get_dev_examples(self, data_dir):
        return self._create_examples(
            self._read_csv(os.path.join(data_dir, 'devData.csv')), 'dev'
        )

    def get_test_examples(self, data_dir):
        return self._create_examples(
            self._read_csv(os.path.join(data_dir, 'testData.csv')), 'test'
        )

    # 1. read data from file
    @classmethod
    def _read_csv(cls, input_file, split=',', quotechar=None):
        """Reads a n separated value file."""
        with tf.gfile.Open(input_file, "r") as f:
            reader = csv.reader(f, delimiter=split, quotechar=quotechar)
            lines = []
            for line in reader:
                lines.append(line)
            return lines

    # 2. create Bert readable example using InputExample
    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = tokenization.convert_to_unicode(line[0])
            label = tokenization.convert_to_unicode(line[1])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))

        return examples

    # 3. labels
    def get_labels(self):
        """See base class."""
        return ['0', '1']

1.2 注册

def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "xnli": XnliProcessor,
        # TODO: 将定义好的Processor在此注册
        "mytask": ……,
    }

  ……

2. 运行

执行脚本的方法如下,在执行时需要的具体参数在下面详细解释:

python run_classifier.py 
  --data_dir=$MY_DATASET 
  --task_name=sim 
  --vocab_file=$BERT_BASE_DIR/vocab.txt 
  --bert_config_file=$BERT_BASE_DIR/bert_config.json 
  --output_dir=/tmp/sim_model/ 
  --do_train=true 
  --do_eval=true 
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt 
  --max_seq_length=128 
  --train_batch_size=32 
  --learning_rate=5e-5
  --num_train_epochs=2.0

2.1 数据和模型参数

执行run_classifier.py时,有5个必填参数: data_dir, task_name, vocab_file, bert_config_file, out_put

8c402665829702417c0fa75ef8499611.png
# import tensorflow as tf
# flags = tf.flags
# tf的参数解析
if __name__ == "__main__":
  flags.mark_flag_as_required("data_dir")
  flags.mark_flag_as_required("task_name")
  flags.mark_flag_as_required("vocab_file")
  flags.mark_flag_as_required("bert_config_file")
  flags.mark_flag_as_required("output_dir")
  tf.app.run()

2.2 流程参数

  • 控制训练,验证,测试流程

9aa4c1341252943b676616ea9d637d88.png

2.3 训练参数

  • 训练任务的具体参数

a68907df38ee03992eaeb563fb7909d1.png

3. 嵌入自己的训练模型

利用bert获得句/词向量,喂给自己的模型: (具体方式见代码注释)

  • 向量获取
    • 句向量:model.get_pooled_output()
    • 词向量:model.get_sequence_output()
  • 模型添加
  • 返回结果
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, labels, num_labels, use_one_hot_embeddings):
    """Creates a classification model."""
    model = modeling.BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

    # In the demo, we are doing a simple classification task on the entire segment.
    # If you want to use the token-level output, use model.get_sequence_output() instead.

    # TODO: 此处样本经过了bert模型,如果需要取出样本的向量喂给其他模型(分类or其他),可以用以下两个方法:
    # 1. 句向量:model.get_pooled_output()
    # 2. 词向量:model.get_sequence_output()
    #
    # 后续模型在此处继续顺序执行
    output_layer = model.get_pooled_output()

    hidden_size = output_layer.shape[-1].value

    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))

    output_bias = tf.get_variable(
        "output_bias", [num_labels], initializer=tf.zeros_initializer())

    with tf.variable_scope("loss"):
        if is_training:
        # I.e., 0.1 dropout
        output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        probabilities = tf.nn.softmax(logits, axis=-1)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        loss = tf.reduce_mean(per_example_loss)

        return (loss, per_example_loss, logits, probabilities)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值