bert保存模型(输入不能分开IteratorV2)

今天下载google官方bert代码,发现使用estimator训练成ckpt格式,转成pb以后,发现输入融合到了一起。
在这里插入图片描述
使用pb推理的时候,用see.run出现错误。

inputs = graph.get_tensor_by_name(prex + "IteratorV2:0")
# input_mask = graph.get_tensor_by_name(prex + "input_mask:0")
# token_type_ids = graph.get_tensor_by_name(prex + "token_type_ids:0")
logits = graph.get_tensor_by_name(prex + "loss/Softmax:0")
s_time = time.time()    
logits = sess.run([logits],feed_dict={inputs:[[input_feature.input_ids],[input_feature.input_mask],[input_feature.label_id],[input_feature.max_len],[input_feature.token_type_ids]]})
interval_time = time.time()-s_time
return logits,interval_time

—————————————————————————————————————————————
于是打算把模型保存成saved_model格式进行推理,在bert训练代码中加入。

def create_serving_input_receiver_fn():
    """ Builds a serving_inputer_receiver_fn
    Arguments
    ---------
    max_seq_length: int
        Specifies the sequence length
    Returns
    -------
    serving_input_receiver_fn()
    """

    def serving_input_receiver_fn():
        """ Creates an serving_input_receiver_fn for BERT"""

        input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name="input_ids")
        input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name="input_mask")
        token_type_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name="token_type_ids")
        label_ids = tf.placeholder(tf.int32, [None, None], name="label_ids")
        valid_length = tf.placeholder(tf.int32, [None], name="valid_length")
        return tf.estimator.export.build_raw_serving_input_receiver_fn(
            {
                "input_ids": input_ids,
                "input_mask": input_mask,
                "token_type_ids": token_type_ids,
                "label_ids": label_ids,
                "valid_length": valid_length
            }
        )()

    return serving_input_receiver_fn
serving_input_receiver_fn = create_serving_input_receiver_fn()
model_dir =os.path.join(FLAGS.output_dir, "saved_model")
estimator.export_saved_model(model_dir, serving_input_receiver_fn)

保存成saved_model格式,这时候查看图中,有6个placeholder。
在这里插入图片描述
推理代码

predict_fn = predictor.from_saved_model(latest)
instance = {
              "input_ids": np.reshape(x_id, newshape=[1, -1]),
               "input_mask": np.reshape(x_mask, newshape=[1, -1]),
               "token_type_ids": np.reshape(x_segment, newshape=[1, -1]),
               "label_ids": np.reshape(x_label_id, newshape=[1,-1]),
               "valid_length":np.reshape(x_valid_length, newshape=[1])
                       }
result = predict_fn(instance)

完结撒花!!!!!

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值