bert 三种模型保存的方式以及调用方法总结(ckpt,单文件pb,tf_serving使用的pb)

目录

1、在训练的过程中保存的ckpt文件:

2、pb文件,直接包含图结构和变量值,加载时只需要一个文件即可

3、tf.saved_model 格式:保存后得到一个 saved_model.pb 文件和一个 variables 文件夹(tf_serving 使用的就是这种格式)。


1、在训练的过程中保存的ckpt文件:

保存时主要有四个文件:

1)checkpoint:指示当前目录有哪些模型文件以及最新的模型文件

内容举例:

  model_checkpoint_path: "model.ckpt-2625"
  all_model_checkpoint_paths: "model.ckpt-2000"
  all_model_checkpoint_paths: "model.ckpt-2625"

2)model.ckpt-2625.data-00000-of-00001

包含所有训练变量取值的文件。BERT 训练过程中保存的这个文件约 1.2GB,这是因为除了记录每个变量的值之外,还记录了 Adam 优化器为每个变量维护的一阶矩和二阶矩(即 Adam 中的 m 和 v),所以文件大小约为模型参数本身的 3 倍。

3)model.ckpt-2625.index

索引文件:记录每个变量名(key)与它在 data 文件中的存储位置等元信息(value)之间的对应关系。

4)model.ckpt-2625.meta

保存 MetaGraphDef,描述整个计算图(网络)的结构。

保存:ckpt 文件可以通过以下两种方式产生:

1)tf.train.Saver:


# Keep only the 5 most recent checkpoints (max_to_keep=5).
saver = tf.train.Saver(max_to_keep=5)
# global_step is appended to the save path, giving files such as path-<epoch>.index.
saver.save(sess=sess, save_path='path', global_step=epoch)

2)estimator.train(input_fn=train_input_fn, max_steps=num_train_steps):训练过程中会按照 RunConfig 中的 save_checkpoints_steps,自动把 ckpt 保存到 model_dir 目录下。

加载:

1)saver=tf.train.import_meta_graph('model/model.meta') #恢复计算图结构
saver.restore(sess, tf.train.latest_checkpoint("model/")) #恢复所有变量信息

2)estimator:Estimator 会从 RunConfig 指定的 model_dir 中加载最新的 ckpt 进行预测
 
def prepare(self):
        """Build the BERT config, tokenizer, data processor, RunConfig and SavedModel predictor.

        Side effects: sets self.config, self.tokenizer, self.processor,
        self.train_examples, self.run_config and self.predict_fn, and rebinds the
        module-level global ``label_list`` that the predict_* methods read.

        Raises:
            ValueError: if arg_dic['max_seq_length'] exceeds the model's
                max_position_embeddings.
        """
        global label_list

        tokenization.validate_case_matches_checkpoint(arg_dic['do_lower_case'], arg_dic['init_checkpoint'])
        self.config = modeling.BertConfig.from_json_file(arg_dic['bert_config_file'])

        # BERT cannot encode sequences longer than its position-embedding table.
        seq_len = arg_dic['max_seq_length']
        max_positions = self.config.max_position_embeddings
        if seq_len > max_positions:
            raise ValueError(
                "Cannot use sequence length %d because the BERT model "
                "was only trained up to sequence length %d" % (seq_len, max_positions))

        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=arg_dic['vocab_file'],
            do_lower_case=arg_dic['do_lower_case'],
        )

        self.processor = SelfProcessor()
        self.train_examples = self.processor.get_train_examples(arg_dic['data_dir'])
        label_list = self.processor.get_labels()

        self.run_config = tf.estimator.RunConfig(
            model_dir=arg_dic['output_dir'],
            save_checkpoints_steps=arg_dic['save_checkpoints_steps'],
            tf_random_seed=None,
            save_summary_steps=100,
            session_config=None,
            keep_checkpoint_max=5,
            keep_checkpoint_every_n_hours=10000,
            log_step_count_steps=100,
        )
        # NOTE(review): this always loads a SavedModel from the relative directory
        # "pb_save_test", even though the ckpt-based prediction does not need it;
        # presumably it fails if that export has not been created yet — confirm.
        self.predict_fn = tf.contrib.predictor.from_saved_model("pb_save_test")
        
    def predict_on_ckpt(self, sentence):
        """Classify one sentence with an Estimator that restores the checkpoints in run_config.model_dir.

        The Estimator is built on the first call and cached in ``self.ckpt_tool``.

        Returns:
            The entry of the global ``label_list`` with the highest predicted probability.
        """
        if not self.ckpt_tool:
            # Step counts are required by model_fn_builder; presumably they only shape the
            # training schedule and do not affect prediction.
            total_steps = int(len(self.train_examples) / arg_dic['train_batch_size'] * arg_dic['num_train_epochs'])
            warmup_steps = int(total_steps * arg_dic['warmup_proportion'])

            model_fn = model_fn_builder(
                bert_config=self.config,
                num_labels=len(label_list),
                init_checkpoint=arg_dic['init_checkpoint'],
                learning_rate=arg_dic['learning_rate'],
                num_train=total_steps,
                num_warmup=warmup_steps,
            )
            self.ckpt_tool = tf.estimator.Estimator(model_fn=model_fn, config=self.run_config)

        example = self.processor.one_example(sentence)  # example(s) built from the raw sentence
        feature = convert_single_example(0, example, label_list, arg_dic['max_seq_length'], self.tokenizer)

        input_fn = input_fn_builder(
            features=[feature],
            seq_length=arg_dic['max_seq_length'],
            is_training=False,
            drop_remainder=False,
        )
        # predict() yields one dict per input; drain it and keep the single prediction.
        first = list(self.ckpt_tool.predict(input_fn=input_fn))[0]
        probs = first["probabilities"].tolist()
        best = probs.index(max(probs))  # index of the highest probability
        return label_list[best]

2、单文件 pb:把变量的值冻结成常量,与图结构一起写入同一个文件,加载时只需要这一个文件即可

1)保存:

 pb_file = os.path.join(arg_dic['pb_model_dir'], 'classification_model.pb')
        graph = tf.Graph()
        with graph.as_default():
            input_ids = tf.placeholder(tf.int32, (None, arg_dic['max_seq_length']), 'input_ids')
            input_mask = tf.placeholder(tf.int32, (None, arg_dic['max_seq_length']), 'input_mask')
            bert_config = modeling.BertConfig.from_json_file(arg_dic['bert_config_file'])
            loss, per_example_loss, logits, probabilities = create_classification_model(
                bert_config=bert_config, is_training=False,
                input_ids=input_ids, input_mask=input_mask, segment_ids=None, labels=None, num_labels=num_labels)

            probabilities = tf.identity(probabilities, 'pred_prob')
            saver = tf.train.Saver()

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                latest_checkpoint = tf.train.latest_checkpoint(arg_dic['output_dir'])
                saver.restore(sess, latest_checkpoint)
                from tensorflow.python.framework import graph_util
                tmp_g = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['pred_prob'])
        # 存储二进制模型到文件中
        with tf.gfile.GFile(pb_file, 'wb') as f:
            f.write(tmp_g.SerializeToString())
        return pb_file
    except Exception as e:
        print('fail to optimize the graph! %s', e)

2)加载:

def classification_model_fn(self, features, mode):
        """Estimator model_fn that runs the frozen classification graph stored at self.graph_path.

        The graph is imported with its 'input_ids' / 'input_mask' placeholders mapped to
        the incoming feature tensors, and its 'pred_prob:0' output is turned into:
          - 'encodes': argmax over the last axis (predicted class index)
          - 'score':   max over the last axis (probability of that class)
        """
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(self.graph_path, 'rb') as pb_fh:
            graph_def.ParseFromString(pb_fh.read())

        pred_probs = tf.import_graph_def(
            graph_def,
            name='',
            input_map={"input_ids": features["input_ids"], "input_mask": features["input_mask"]},
            return_elements=['pred_prob:0'],
        )
        probs = pred_probs[0]

        predictions = {
            'encodes': tf.argmax(probs, axis=-1),
            'score': tf.reduce_max(probs, axis=-1),
        }
        return EstimatorSpec(mode=mode, predictions=predictions)


   def predict_on_pb(self, sentence):
        """Classify one sentence with an Estimator wrapping the frozen .pb graph.

        The Estimator (model_fn=self.classification_model_fn) is created on the first
        call and cached in ``self.pbTool``. The predicted label and its score are printed.

        Returns:
            The predicted entry of the global ``label_list``.
        """
        if not self.pbTool:
            self.pbTool = tf.estimator.Estimator(model_fn=self.classification_model_fn, config=self.run_config)

        example = self.processor.one_example(sentence)  # example(s) built from the raw sentence
        max_len = arg_dic['max_seq_length']
        feature = convert_single_example(0, example, label_list, max_len, self.tokenizer)
        input_fn = input_fn_builder(features=[feature], seq_length=max_len, is_training=False, drop_remainder=False)

        # predict() yields one dict per input; drain it and keep the single prediction.
        prediction = list(self.pbTool.predict(input_fn=input_fn))[0]
        label = label_list[prediction['encodes']]
        print('类别:{},置信度:{:.3f}'.format(label, prediction['score']))
        return label

3、tf.saved_model 格式:保存后得到一个 saved_model.pb 文件和一个 variables 文件夹(tf_serving 使用的就是这种格式)。

保存有两种方式:

1)graph = tf.Graph()
        with graph.as_default():
            input_ids = tf.placeholder(tf.int32, (None, arg_dic['max_seq_length']), 'input_ids')
            input_mask = tf.placeholder(tf.int32, (None, arg_dic['max_seq_length']), 'input_mask')
            bert_config = modeling.BertConfig.from_json_file(arg_dic['bert_config_file'])
            loss, per_example_loss, logits, probabilities = create_classification_model(
                bert_config=bert_config, is_training=False,
                input_ids=input_ids, input_mask=input_mask, segment_ids=None, labels=None, num_labels=num_labels)

            probabilities = tf.identity(probabilities, 'pred_prob')
            saver = tf.train.Saver()

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                latest_checkpoint = tf.train.latest_checkpoint(arg_dic['output_dir'])
                saver.restore(sess, latest_checkpoint)
                path_pb_model = "pb_save_test"
                builder = tf.saved_model.builder.SavedModelBuilder(path_pb_model)  # 创建一个保存模型的实例对象    
                # 构建需要在新会话中恢复的变量的 TensorInfo protobuf
                input_ids1 = tf.saved_model.utils.build_tensor_info(input_ids)
                input_mask1 = tf.saved_model.utils.build_tensor_info(input_mask)
                probabilities1 = tf.saved_model.utils.build_tensor_info(probabilities)
                # 构建 SignatureDef protobuf
                signature_def = tf.saved_model.signature_def_utils.build_signature_def(
                inputs={'input_ids': input_ids1, 'input_mask': input_mask1},
                outputs={'probabilities':probabilities1 },
                method_name='test')
                # 将 graph 和变量等信息写入 MetaGraphDef protobuf
                # 这里的 tags 里面的参数和 signature_def_map 字典里面的键都可以是自定义字符串,TensorFlow 为了方便使用,可以使用预定义的这些值
                builder.add_meta_graph_and_variables(sess, 
                                         tags=[tf.saved_model.tag_constants.SERVING], 
                                         signature_def_map={tf.saved_model.signature_constants.CLASSIFY_INPUTS: signature_def}) 

                # 将 MetaGraphDef 写入磁盘
                builder.save()
2)estimator.export_savedmodel:注意它每次都会在目标目录下新建一个以时间戳命名的子目录,模型实际保存在该子目录中,加载时要指向这个子目录。
def serving_input_receiver_fn():
    """Build the serving input receiver passed to ``export_savedmodel``.

    At serving time the model receives a batch of serialized ``tf.train.Example``
    protos through the ``'examples'`` input. Each must contain ``input_ids``,
    ``input_mask`` and ``segment_ids`` as int64 lists of length
    ``arg_dic['max_seq_length']``.

    :return: a ``tf.estimator.export.ServingInputReceiver``
    """
    seq_len = arg_dic['max_seq_length']
    feature_spec = {
        name: tf.FixedLenFeature([seq_len], tf.int64)
        for name in ("input_ids", "input_mask", "segment_ids")
    }
    examples_placeholder = tf.placeholder(dtype=tf.string, shape=[None], name='input_example_tensor')
    parsed_features = tf.parse_example(examples_placeholder, feature_spec)
    return tf.estimator.export.ServingInputReceiver(parsed_features, {'examples': examples_placeholder})

if arg_dic['do_predict']:
    # NOTE(review): `_export_to_tpu` is a private TPUEstimator attribute; setting it to False
    # presumably skips exporting the extra TPU meta graph — confirm for your TF version.
    estimator._export_to_tpu = False
    # export_savedmodel() creates a new timestamped sub-directory under "pb_save_test" on
    # every call and returns its path; a loader must point at that sub-directory.
    export_dir = estimator.export_savedmodel("pb_save_test", serving_input_receiver_fn)
    print('SavedModel exported to %s' % tf.compat.as_str(export_dir))

模型的加载(对应上面用 estimator 导出、以 'examples' 作为输入的模型):

 tokenization.validate_case_matches_checkpoint(arg_dic['do_lower_case'], arg_dic['init_checkpoint'])
        self.config = modeling.BertConfig.from_json_file(arg_dic['bert_config_file'])

        if arg_dic['max_seq_length'] > self.config.max_position_embeddings:
            raise ValueError(
                "Cannot use sequence length %d because the BERT model "
                "was only trained up to sequence length %d" %
                (arg_dic['max_seq_length'], self.config.max_position_embeddings))

        # tf.gfile.MakeDirs(self.out_dir)
        self.tokenizer = tokenization.FullTokenizer(vocab_file=arg_dic['vocab_file'],
                                                    do_lower_case=arg_dic['do_lower_case'])

        self.processor = SelfProcessor()
        global label_list
        label_list = self.processor.get_labels()
        self.predict_fn = tf.contrib.predictor.from_saved_model("/home/hadoop-health-alg/TextClassify_with_BERT/pb_save_test")


    def predict_on_pb(self, sentence):
        """Classify one sentence with the SavedModel predictor in ``self.predict_fn``.

        The sentence is converted to BERT features, packed into one serialized
        ``tf.train.Example`` (input_ids / input_mask / segment_ids as int64 lists) and
        fed through the predictor's 'examples' input. The predicted label and its
        probability are printed.

        Returns:
            The entry of the global ``label_list`` with the highest probability.
        """
        example = self.processor.one_example(sentence)  # example(s) built from the raw sentence
        feature = convert_single_example(0, example, label_list, arg_dic['max_seq_length'], self.tokenizer)

        def _int64_feature(values):
            # Wrap a list of ints as a tf.train.Feature.
            return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

        tf_example = tf.train.Example(features=tf.train.Features(feature={
            'input_ids': _int64_feature(feature.input_ids),
            'input_mask': _int64_feature(feature.input_mask),
            'segment_ids': _int64_feature(feature.segment_ids),
        }))
        # The predictor takes a batch; here the batch holds a single serialized Example.
        predictions = self.predict_fn({'examples': [tf_example.SerializeToString()]})
        probs = predictions['probabilities'].tolist()[0]
        pos = probs.index(max(probs))  # index of the highest probability

        print('类别:{},置信度:{:.3f}'.format(label_list[pos], probs[pos]))
        return label_list[pos]

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

samoyan

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值