目录
2、pb文件,直接包含图结构和变量值,加载时只需要一个文件即可
3)tf.saved_model 保存得到一个pb文件和一个variables文件夹。
1、在训练的过程中保存的ckpt文件:
保存时主要有四个文件:
1)checkpoint:指示当前目录有哪些模型文件以及最新的模型文件
内容举例:
model_checkpoint_path: "model.ckpt-2625"
all_model_checkpoint_paths: "model.ckpt-2000"
all_model_checkpoint_paths: "model.ckpt-2625"
2)model.ckpt-2625.data-00000-of-00001
包含训练变量取值的文件,在bert训练过程中,约1.2G,这是由于除了记录每个变量的值,还记录了优化器的一阶矩和二阶矩估计,即adam当中的m和v
3)model.ckpt-2625.index
索引文件,记录每个变量名(key)与其在data文件中的存储信息(value,如所在文件、偏移量等)之间的对应关系。
4)model.ckpt-2625.meta
描述整个网络的结构。
保存:可以由两种方式产生:
1).tf.train.Saver
saver=tf.train.Saver(max_to_keep=5) #max_to_keep=5意思就是保存最近的5个模型
saver.save(sess,'path',global_step=epoch)
2) estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
加载:
1)saver=tf.train.import_meta_graph('model/model.meta') #恢复计算图结构
saver.restore(sess, tf.train.latest_checkpoint("model/")) #恢复所有变量信息
2)estimator
def prepare(self):
    """Load configuration, tokenizer, data and predictors used for inference.

    All settings are read from the module-level ``arg_dic``. The label set
    is published through the module-level ``label_list`` global so that the
    prediction methods can map indices back to labels.

    Raises:
        ValueError: if ``arg_dic['max_seq_length']`` is larger than the
            ``max_position_embeddings`` of the loaded BERT config.
    """
    tokenization.validate_case_matches_checkpoint(arg_dic['do_lower_case'], arg_dic['init_checkpoint'])
    self.config = modeling.BertConfig.from_json_file(arg_dic['bert_config_file'])
    # Reject sequence lengths beyond what the config's position embeddings cover.
    if arg_dic['max_seq_length'] > self.config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (arg_dic['max_seq_length'], self.config.max_position_embeddings))

    # tf.gfile.MakeDirs(self.out_dir)
    self.tokenizer = tokenization.FullTokenizer(vocab_file=arg_dic['vocab_file'],
                                                do_lower_case=arg_dic['do_lower_case'])
    self.processor = SelfProcessor()
    # The training examples are kept because predict_on_ckpt derives the
    # step counts for model_fn_builder from their number.
    self.train_examples = self.processor.get_train_examples(arg_dic['data_dir'])
    global label_list
    label_list = self.processor.get_labels()

    # Estimator run configuration; model_dir points at the training output
    # directory that holds the checkpoints.
    self.run_config = tf.estimator.RunConfig(
        model_dir=arg_dic['output_dir'], save_checkpoints_steps=arg_dic['save_checkpoints_steps'],
        tf_random_seed=None, save_summary_steps=100, session_config=None, keep_checkpoint_max=5,
        keep_checkpoint_every_n_hours=10000, log_step_count_steps=100, )
    # Predictor backed by the SavedModel stored in the "pb_save_test" directory.
    self.predict_fn = tf.contrib.predictor.from_saved_model("pb_save_test")
def predict_on_ckpt(self, sentence):
    """Classify ``sentence`` with an Estimator configured by ``self.run_config``.

    The Estimator is created on first use and cached on ``self.ckpt_tool``.

    Args:
        sentence: raw input handed to ``self.processor.one_example``.

    Returns:
        The entry of the module-level ``label_list`` whose predicted
        probability is highest.
    """
    # getattr tolerates the attribute never having been initialised, which
    # would otherwise raise AttributeError on the very first call.
    if not getattr(self, 'ckpt_tool', None):
        num_train_steps = int(len(self.train_examples) / arg_dic['train_batch_size'] * arg_dic['num_train_epochs'])
        num_warmup_steps = int(num_train_steps * arg_dic['warmup_proportion'])
        model_fn = model_fn_builder(bert_config=self.config, num_labels=len(label_list),
                                    init_checkpoint=arg_dic['init_checkpoint'],
                                    learning_rate=arg_dic['learning_rate'],
                                    num_train=num_train_steps, num_warmup=num_warmup_steps)
        self.ckpt_tool = tf.estimator.Estimator(model_fn=model_fn, config=self.run_config, )

    # Turn the single sentence into a one-element feature list.
    exam = self.processor.one_example(sentence)
    feature = convert_single_example(0, exam, label_list, arg_dic['max_seq_length'], self.tokenizer)
    predict_input_fn = input_fn_builder(features=[feature, ],
                                        seq_length=arg_dic['max_seq_length'], is_training=False,
                                        drop_remainder=False)

    # predict() yields results lazily; materialise them and take the first one.
    result = self.ckpt_tool.predict(input_fn=predict_input_fn)
    probabilities = list(result)[0]["probabilities"].tolist()
    best_index = probabilities.index(max(probabilities))  # index of the highest probability
    return label_list[best_index]
2、pb文件,直接包含图结构和变量值,加载时只需要一个文件即可
1)保存:
pb_file = os.path.join(arg_dic['pb_model_dir'], 'classification_model.pb')
graph = tf.Graph()
with graph.as_default():
input_ids = tf.placeholder(tf.int32, (None, arg_dic['max_seq_length']), 'input_ids')
input_mask = tf.placeholder(tf.int32, (None, arg_dic['max_seq_length']), 'input_mask')
bert_config = modeling.BertConfig.from_json_file(arg_dic['bert_config_file'])
loss, per_example_loss, logits, probabilities = create_classification_model(
bert_config=bert_config, is_training=False,
input_ids=input_ids, input_mask=input_mask, segment_ids=None, labels=None, num_labels=num_labels)
probabilities = tf.identity(probabilities, 'pred_prob')
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
latest_checkpoint = tf.train.latest_checkpoint(arg_dic['output_dir'])
saver.restore(sess, latest_checkpoint)
from tensorflow.python.framework import graph_util
tmp_g = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['pred_prob'])
# 存储二进制模型到文件中
with tf.gfile.GFile(pb_file, 'wb') as f:
f.write(tmp_g.SerializeToString())
return pb_file
except Exception as e:
print('fail to optimize the graph! %s', e)
2)加载:
def classification_model_fn(self, features, mode):
    """Estimator ``model_fn`` that predicts from the frozen graph at ``self.graph_path``.

    Args:
        features: mapping that provides the ``"input_ids"`` and
            ``"input_mask"`` tensors.
        mode: Estimator mode, passed through to the returned spec unchanged.

    Returns:
        An ``EstimatorSpec`` whose predictions are ``'encodes'`` (argmax of
        the ``pred_prob`` tensor over the last axis) and ``'score'`` (its
        maximum over the last axis).
    """
    # Deserialize the GraphDef stored in the .pb file.
    with tf.gfile.GFile(self.graph_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Wire the Estimator's feature tensors to the graph's input nodes.
    input_map = {"input_ids": features["input_ids"],
                 "input_mask": features["input_mask"]}
    pred_probs = tf.import_graph_def(graph_def, name='', input_map=input_map,
                                     return_elements=['pred_prob:0'])
    return EstimatorSpec(mode=mode, predictions={
        'encodes': tf.argmax(pred_probs[0], axis=-1),
        'score': tf.reduce_max(pred_probs[0], axis=-1)})
def predict_on_pb(self, sentence):
    """Classify ``sentence`` with the frozen-graph Estimator.

    The Estimator wraps ``self.classification_model_fn``; it is created on
    first use and cached on ``self.pbTool``.

    Args:
        sentence: raw input handed to ``self.processor.one_example``.

    Returns:
        The entry of the module-level ``label_list`` selected by the
        predicted class index. The label and its score are also printed.
    """
    # getattr tolerates the attribute never having been initialised, which
    # would otherwise raise AttributeError on the very first call.
    if not getattr(self, 'pbTool', None):
        self.pbTool = tf.estimator.Estimator(model_fn=self.classification_model_fn, config=self.run_config, )

    # Turn the single sentence into a one-element feature list.
    exam = self.processor.one_example(sentence)
    feature = convert_single_example(0, exam, label_list, arg_dic['max_seq_length'], self.tokenizer)
    predict_input_fn = input_fn_builder(features=[feature, ],
                                        seq_length=arg_dic['max_seq_length'], is_training=False,
                                        drop_remainder=False)

    # predict() yields results lazily; materialise them and take the first one.
    result = self.pbTool.predict(input_fn=predict_input_fn)
    ele = list(result)[0]
    label = label_list[ele['encodes']]
    print('类别:{},置信度:{:.3f}'.format(label, ele['score']))
    return label
3)tf.saved_model 保存得到一个pb文件和一个variables文件夹。
保存有两种方式:
1)graph = tf.Graph()
with graph.as_default():
    # Rebuild the inference graph with explicit placeholders as inputs.
    input_ids = tf.placeholder(tf.int32, (None, arg_dic['max_seq_length']), 'input_ids')
    input_mask = tf.placeholder(tf.int32, (None, arg_dic['max_seq_length']), 'input_mask')
    bert_config = modeling.BertConfig.from_json_file(arg_dic['bert_config_file'])
    loss, per_example_loss, logits, probabilities = create_classification_model(
        bert_config=bert_config, is_training=False,
        input_ids=input_ids, input_mask=input_mask, segment_ids=None, labels=None, num_labels=num_labels)
    # Give the output tensor a stable name.
    probabilities = tf.identity(probabilities, 'pred_prob')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Overwrite the freshly initialised variables with the trained values.
        latest_checkpoint = tf.train.latest_checkpoint(arg_dic['output_dir'])
        saver.restore(sess, latest_checkpoint)

        path_pb_model = "pb_save_test"
        builder = tf.saved_model.builder.SavedModelBuilder(path_pb_model)  # object that writes the SavedModel

        # Build TensorInfo protobufs for the tensors exposed by the signature.
        input_ids1 = tf.saved_model.utils.build_tensor_info(input_ids)
        input_mask1 = tf.saved_model.utils.build_tensor_info(input_mask)
        probabilities1 = tf.saved_model.utils.build_tensor_info(probabilities)

        # Build the SignatureDef protobuf. The standard predict method name
        # is used instead of an ad-hoc string so serving tools recognise it.
        signature_def = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={'input_ids': input_ids1, 'input_mask': input_mask1},
            outputs={'probabilities': probabilities1},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

        # Write the graph and variables into a MetaGraphDef protobuf.
        # Tags and signature_def_map keys may be arbitrary strings, but loaders
        # such as tf.contrib.predictor.from_saved_model look up the default
        # serving key when none is given, so the signature is registered under
        # it. The previously used key is kept so existing loaders still work.
        builder.add_meta_graph_and_variables(
            sess,
            tags=[tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def,
                tf.saved_model.signature_constants.CLASSIFY_INPUTS: signature_def,
            })
        # Write the MetaGraphDef to disk.
        builder.save()
2)estimator
def serving_input_receiver_fn():
    """Build the input receiver used to accept data at serving time.

    The exported model receives a batch of serialized ``tf.Example`` strings
    under the key ``'examples'`` and parses each one into fixed-length int64
    features ``input_ids``, ``input_mask`` and ``segment_ids`` of length
    ``arg_dic['max_seq_length']``.

    :return: a ``tf.estimator.export.ServingInputReceiver`` pairing the
        parsed features with the raw receiver placeholder.
    """
    seq_length = arg_dic['max_seq_length']
    feature_spec = {
        "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
    }
    # Placeholder that receives the batch of serialized examples.
    serialized_tf_example = tf.placeholder(dtype=tf.string,
                                           shape=[None],
                                           name='input_example_tensor')
    receiver_tensors = {'examples': serialized_tf_example}
    features = tf.parse_example(serialized_tf_example, feature_spec)
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
if arg_dic['do_predict']:
    # NOTE(review): presumably `estimator` is a TPUEstimator and this private
    # flag turns off the additional TPU export -- confirm against the caller.
    estimator._export_to_tpu = False
    # "pb_save_test" is the export base directory handed to the Estimator.
    estimator.export_savedmodel("pb_save_test", serving_input_receiver_fn)
模型的加载:
# Loader-side initialisation: validate the settings, build the tokenizer and
# processor, publish the label list and load the SavedModel predictor.
tokenization.validate_case_matches_checkpoint(arg_dic['do_lower_case'], arg_dic['init_checkpoint'])
self.config = modeling.BertConfig.from_json_file(arg_dic['bert_config_file'])
# Reject sequence lengths beyond what the config's position embeddings cover.
if arg_dic['max_seq_length'] > self.config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (arg_dic['max_seq_length'], self.config.max_position_embeddings))
# tf.gfile.MakeDirs(self.out_dir)
self.tokenizer = tokenization.FullTokenizer(vocab_file=arg_dic['vocab_file'],
                                            do_lower_case=arg_dic['do_lower_case'])
self.processor = SelfProcessor()
global label_list
label_list = self.processor.get_labels()
# Predictor backed by the SavedModel stored at this absolute path.
self.predict_fn = tf.contrib.predictor.from_saved_model("/home/hadoop-health-alg/TextClassify_with_BERT/pb_save_test")
def predict_on_pb(self, sentence):
    """Classify ``sentence`` with the SavedModel predictor ``self.predict_fn``.

    The sentence is converted to a serialized ``tf.Example`` carrying
    ``input_ids``, ``input_mask`` and ``segment_ids`` and fed to the
    predictor under the key ``'examples'``.

    Args:
        sentence: raw input handed to ``self.processor.one_example``.

    Returns:
        The entry of the module-level ``label_list`` whose predicted
        probability is highest.
    """
    exam = self.processor.one_example(sentence)
    feature = convert_single_example(0, exam, label_list, arg_dic['max_seq_length'], self.tokenizer)

    # Pack the integer features into a tf.Example and serialize it.
    features = {
        'input_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=feature.input_ids)),
        'input_mask': tf.train.Feature(int64_list=tf.train.Int64List(value=feature.input_mask)),
        'segment_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=feature.segment_ids)),
    }
    example = tf.train.Example(features=tf.train.Features(feature=features))
    examples = [example.SerializeToString()]

    predictions = self.predict_fn({'examples': examples})
    result = predictions['probabilities'].tolist()
    pos = result[0].index(max(result[0]))  # index of the highest probability
    # Debug output retained from the original implementation.
    print("hahahah",result,label_list,pos)
    return label_list[pos]