今天下载了 Google 官方的 BERT 代码，发现用 Estimator 训练得到 ckpt 格式的模型，转成 pb 以后，输入被融合到了一起（图里只剩一个 IteratorV2 输入节点）。
使用 pb 推理时，调用 sess.run 出现错误：
# Frozen-graph (.pb) inference attempt. This is the body of a function whose
# `def` line is not shown; `graph`, `sess`, `prex`, `input_feature` and the
# `time` module come from that surrounding code.
# NOTE(review): the only input tensor looked up is the Estimator dataset
# iterator "IteratorV2:0" -- the separate feature inputs appear to have been
# folded into it when the checkpoint was converted to .pb. Feeding a nested
# list of feature values to this tensor is the sess.run call that fails;
# presumably an iterator handle cannot be fed feature data directly -- confirm.
inputs = graph.get_tensor_by_name(prex + "IteratorV2:0")
# Per-feature input lookups, left disabled -- presumably because the converted
# graph no longer exposes these tensors separately; kept for reference.
# input_mask = graph.get_tensor_by_name(prex + "input_mask:0")
# token_type_ids = graph.get_tensor_by_name(prex + "token_type_ids:0")
# Output tensor to fetch: the "loss/Softmax" node of the graph.
logits = graph.get_tensor_by_name(prex + "loss/Softmax:0")
# Wall-clock timing covers only the sess.run call below.
s_time = time.time()
# NOTE(review): fetching `[logits]` (a list) makes sess.run return a list, so
# the rebound `logits` is a one-element list wrapping the result.
logits = sess.run([logits],feed_dict={inputs:[[input_feature.input_ids],[input_feature.input_mask],[input_feature.label_id],[input_feature.max_len],[input_feature.token_type_ids]]})
interval_time = time.time()-s_time
# Return the fetched result together with the elapsed seconds.
return logits,interval_time
—————————————————————————————————————————————
于是打算把模型导出为 SavedModel 格式再进行推理，在 BERT 训练代码中加入以下代码：
def create_serving_input_receiver_fn(max_seq_length=None):
    """Build a ``serving_input_receiver_fn`` for exporting the BERT estimator.

    The returned function declares one placeholder per model input and hands
    the same dict to ``ServingInputReceiver`` both as the features passed to
    ``model_fn`` and as the tensors exposed in the SavedModel signature.

    Args:
        max_seq_length: Sequence length used for the ``input_ids``,
            ``input_mask`` and ``token_type_ids`` placeholders. When ``None``
            (the default), ``FLAGS.max_seq_length`` is read each time the
            serving function runs, so existing no-argument callers keep their
            behavior.

    Returns:
        A zero-argument callable that returns a
        ``tf.estimator.export.ServingInputReceiver``.
    """
    def serving_input_receiver_fn():
        """Create the BERT serving placeholders and wrap them in a receiver."""
        seq_len = FLAGS.max_seq_length if max_seq_length is None else max_seq_length
        receiver_tensors = {
            "input_ids": tf.placeholder(
                tf.int32, [None, seq_len], name="input_ids"),
            "input_mask": tf.placeholder(
                tf.int32, [None, seq_len], name="input_mask"),
            "token_type_ids": tf.placeholder(
                tf.int32, [None, seq_len], name="token_type_ids"),
            # label_ids / valid_length are exported as inputs as well;
            # presumably the model_fn reads them even in PREDICT mode -- confirm.
            "label_ids": tf.placeholder(
                tf.int32, [None, None], name="label_ids"),
            "valid_length": tf.placeholder(
                tf.int32, [None], name="valid_length"),
        }
        # Wrap the placeholders directly instead of going through
        # build_raw_serving_input_receiver_fn: that helper creates its own
        # fresh placeholder for every tensor it is given, which would
        # duplicate each input (under a uniquified name) and leave this set
        # dangling in the export graph.
        return tf.estimator.export.ServingInputReceiver(
            features=receiver_tensors, receiver_tensors=receiver_tensors)

    return serving_input_receiver_fn
# Export the trained estimator in SavedModel format, using
# <output_dir>/saved_model as the export base directory.
model_dir = os.path.join(FLAGS.output_dir, "saved_model")
serving_input_receiver_fn = create_serving_input_receiver_fn()
estimator.export_saved_model(
    export_dir_base=model_dir,
    serving_input_receiver_fn=serving_input_receiver_fn,
)
导出为 SavedModel 格式后查看图，可以看到其中有 6 个 placeholder。
推理代码：
# Load the exported SavedModel and run one example through it.
# NOTE(review): `latest` is presumably the export directory produced by
# export_saved_model; `predictor`, `x_*` come from code not shown -- confirm.
predict_fn = predictor.from_saved_model(latest)
# Keys mirror the receiver-tensor names used when exporting; each value is
# reshaped to carry a leading batch dimension of 1. The target shape is
# passed positionally because NumPy 2.1 deprecated reshape's `newshape`
# keyword (renamed to `shape`); positional works on every NumPy version.
instance = {
    "input_ids": np.reshape(x_id, [1, -1]),
    "input_mask": np.reshape(x_mask, [1, -1]),
    "token_type_ids": np.reshape(x_segment, [1, -1]),
    "label_ids": np.reshape(x_label_id, [1, -1]),
    "valid_length": np.reshape(x_valid_length, [1]),
}
result = predict_fn(instance)
完结撒花!!!!!