Training a BERT-based text classification model in TensorFlow, saving it as PB (SavedModel), and deploying it

This post walks through building a text classification task on top of BERT and saving the model in .pb format for server deployment. The architecture is BERT plus a softmax layer, and the checkpoint (ckpt) is converted to a pb SavedModel. During conversion, the input nodes (input_ids, input_mask, segment_ids, label_ids) and the output node (probabilities) are specified. At evaluation time, `serving_input_fn` declares the input nodes; feeding input tensors into these nodes runs the model and returns probability values. Finally, the post provides a complete BERT text classification example covering data preprocessing, training, model saving and conversion, and prediction.


Background

I built a text classification task on top of BERT. Because the model had to be deployed to a server, it was saved in pb form.
Model architecture: BERT + softmax
Saving strategy: first save the model as a ckpt, then convert it to pb.
Converting to pb requires specifying the model's input nodes; the code is as follows:

def serving_input_fn():
    # Export the model in SavedModel format.
    # Uses the raw-feature approach: the inputs are feature tensors.
    # With build_parsing_serving_input_receiver_fn, the inputs would be
    # serialized tf.Examples instead.
    label_ids = tf.placeholder(tf.int32, [None, 3], name='label_ids')
    input_ids = tf.placeholder(tf.int32, [None, 200], name='input_ids')
    input_mask = tf.placeholder(tf.int32, [None, 200], name='input_mask')
    segment_ids = tf.placeholder(tf.int32, [None, 200], name='segment_ids')
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'label_ids': label_ids,
        'input_ids': input_ids,
        'input_mask': input_mask,
        'segment_ids': segment_ids,
    })()
    return input_fn

The code that converts the ckpt to pb; here it runs during do_eval:

if FLAGS.do_eval:
    # trans_model_dir is the output directory for the converted model
    estimator._export_to_tpu = False
    estimator.export_savedmodel(FLAGS.trans_model_dir, serving_input_fn)

The generated SavedModel: `export_savedmodel` writes a timestamped subdirectory under `trans_model_dir`, containing the serialized graph in saved_model.pb and the trained weights under the variables folder.
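For reference, a sketch of the typical Estimator export layout (the timestamped directory name here is illustrative and changes on every export):

trans_model_dir/
└── 1581234567/
    ├── saved_model.pb
    └── variables/
        ├── variables.data-00000-of-00001
        └── variables.index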
Inspect the model (run on the command line):

saved_model_cli show --dir save_model/output --all

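The exact output depends on the export, but with the serving_input_fn above it should look roughly like the following (the tensor names are illustrative; the shapes follow the placeholders defined earlier, with three output classes):

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input_ids'] tensor_info:
        dtype: DT_INT32
        shape: (-1, 200)
        name: input_ids:0
    inputs['input_mask'] tensor_info:
        dtype: DT_INT32
        shape: (-1, 200)
        name: input_mask:0
    inputs['label_ids'] tensor_info:
        dtype: DT_INT32
        shape: (-1, 3)
        name: label_ids:0
    inputs['segment_ids'] tensor_info:
        dtype: DT_INT32
        shape: (-1, 200)
        name: segment_ids:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['probabilities'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 3)
        name: loss/Softmax:0
  Method name is: tensorflow/serving/predict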
The signature listing shows the input nodes input_ids, input_mask, segment_ids, and label_ids, plus the output node probabilities. Feed the tensors for your test data into the corresponding input nodes to run the trained model, then fetch the probabilities node to get the result. The result is a list of per-class probability values; applying numpy's argmax to it gives the predicted label. I will cover this step in more detail in a separate post.
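As a quick preview, here is a minimal sketch of calling the exported model from Python with tf.contrib.predictor (TensorFlow 1.x). The export path and the zero-filled features are placeholders; real inputs must be produced by the same tokenization and padding pipeline used during training:

import numpy as np
from tensorflow.contrib import predictor

# Path to the timestamped export directory produced by export_savedmodel
# (hypothetical path; adjust to your own export).
export_dir = 'save_model/output/1581234567'
predict_fn = predictor.from_saved_model(export_dir)

# Zero-filled dummy features for one example. Real inputs must be built with
# the same tokenization and padding used in training (max_seq_length = 200).
features = {
    'input_ids': np.zeros((1, 200), dtype=np.int32),
    'input_mask': np.zeros((1, 200), dtype=np.int32),
    'segment_ids': np.zeros((1, 200), dtype=np.int32),
    # Required by the exported signature even though it is unused at inference.
    'label_ids': np.zeros((1, 3), dtype=np.int32),
}

result = predict_fn(features)
probabilities = result['probabilities']  # shape (1, 3): one probability per class
predicted_label = np.argmax(probabilities, axis=-1)
print(probabilities, predicted_label)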

Complete BERT text classification source code

#!/usr/bin/env python
# -*- coding: utf-8 -*-


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import csv
import os
from bert import tokenization, modeling, optimization
import tensorflow as tf
import pickle

flags = tf.flags

FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string(
    "data_dir", None,
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.")

flags.DEFINE_string(
    "bert_config_file", None,
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.")

flags.DEFINE_string("task_name", None, "The name of the task to train.")

flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")

flags.DEFINE_string(
    "trans_model_dir", None,
    "The trans_model_dir directory where the model will be written.")

## Other parameters

flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained BERT model).")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer(
    "max_seq_length", 200,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_bool("do_train", False, "Whether to run training.")

flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")

flags.DEFINE_bool(
    "do_predict", False,
    "Whether to run the model in inference mode on the test set.")

flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")

flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")

flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")

flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")

flags.DEFINE_float("num_train_epochs", 3.0,
                   "Total number of training epochs to perform.")

flags.DEFINE_float(
    "warmup_proportion", 0.1,
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.")

flags.DEFINE_integer("save_checkpoints_steps", 1000,
                     "How often to save the model checkpoint.")

flags.DEFINE_integer("iterations_per_loop", 1000,
                     "How many steps to make in each estimator call.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

tf.flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")

tf.flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

tf.flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")

flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")


class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs a InputExample.
        Args:
          guid: Unique id for the example.
          text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
          text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
          label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


class PaddingInputExample(object):
    """Fake example so the num input examples is a multiple of the batch size.
    When running eval/predict on the TPU, we need to pad the number of examples
    to be a multiple of the batch size, because the TPU requires a fixed batch
    size. The alternative is to drop the last batch, which is bad because it means
    the entire output data won't be generated.
    We use this class instead of `None` because treating `None` as padding
    batches could cause silent errors.
    """


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self,
                 input_ids,
                 input_mask,
                 segment_ids,
                 label_id,
                 is_real_example=True):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id
        self.is_real_example = is_real_example


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self, data_dir):
        """Gets a collection of `InputExample`s for prediction."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with tf.gfile.Open(input_file, "r") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            lines = []
            for line in reader:
                lines.append(line)
            return lines


'''Custom processor'''

from sklearn.utils import shuffle
import pandas as pd


class MyProcessor(DataProcessor):

    # Read a (label, text) csv file for the given split
    def read_txt(self, filepath, type):
        df = pd.read_csv(filepath + '/' + type + '.csv', delimiter=",",
                         names=['labels', 'text'], header=None, engine='python')
        df = shuffle(df)
        return df  # assumed completion: the original listing is cut off after the shuffle
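The original listing breaks off at this point. For completeness, here is a hedged sketch of how the remaining MyProcessor methods typically look in this pattern; the method bodies, the _create_examples helper, and the label set ['0', '1', '2'] are assumptions inferred from the three-class placeholders above, not the author's exact code:

    # (continuing class MyProcessor)

    def get_train_examples(self, data_dir):
        return self._create_examples(self.read_txt(data_dir, 'train'), 'train')

    def get_dev_examples(self, data_dir):
        return self._create_examples(self.read_txt(data_dir, 'dev'), 'dev')

    def get_test_examples(self, data_dir):
        return self._create_examples(self.read_txt(data_dir, 'test'), 'test')

    def get_labels(self):
        # Three classes, matching the [None, 3] label_ids placeholder above.
        return ['0', '1', '2']

    def _create_examples(self, df, set_type):
        """Builds InputExamples from the (labels, text) rows read above."""
        examples = []
        for i, row in enumerate(df.itertuples(index=False)):
            guid = '%s-%d' % (set_type, i)
            text_a = tokenization.convert_to_unicode(str(row.text))
            label = tokenization.convert_to_unicode(str(row.labels))
            examples.append(InputExample(guid=guid, text_a=text_a, label=label))
        return examples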