java调用文本分类textrnn模型,勿踩坑

11 篇文章 0 订阅
6 篇文章 0 订阅
直接贴代码
# 将模型保存为可用于线上服务的文件(一个.pb文件,一个variables文件夹)
# print('Exporting trained model to', save_dir)
builder = tf.saved_model.builder.SavedModelBuilder(save_dir)

# 服务器专用代码

classification_signature = (
    tf.saved_model.signature_def_utils.build_signature_def(
        inputs={
            # "image"
            "input_x":
                tf.saved_model.utils.build_tensor_info(rnn.input_x),
            "dropout_keep_prob":
                tf.saved_model.utils.build_tensor_info(rnn.dropout_keep_prob)
        },
        outputs={
            # "classify"
            "output":
                tf.saved_model.utils.build_tensor_info(rnn.predictions)
            # "classification_outputs_scores":
            #     tf.saved_model.utils.build_tensor_info(model.logits)
        },
        # Prediction method name used in a SignatureDef.
        # PREDICT_METHOD_NAME = "tensorflow/serving/predict"
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

builder.add_meta_graph_and_variables(
    # saved_model.tag_constants.SERVING = "saved_model.tag_constants.SERVING"
    sess, [tf.saved_model.tag_constants.SERVING],
    # 保存模型的方法名,与客户端的request.model_spec.signature_name对应
    signature_def_map={
        # "predict_image"
        "classification":
            classification_signature},
    legacy_init_op=legacy_init_op)

builder.save()

1、"input_x","output"这个千万不要乱写,因为你java调用的时候必须前后这个命名一致,否则会导致java调用模型预测结果与 python模型结果存在很大的差别

2、rnn是你的模型,rnn = TextRnn(config)

3、rnn.dropout_keep_prob是你的drop的命名方式,一定得和后续的一致

4、rnn.dropout_keep_prob与3一样

5、python跑模型的tensorflow的版本必须和java调用的版本一样!!!

模型格式如下:

现在模型准备好了就开始java调用了:

SavedModelBundle modelBundle = SavedModelBundle.load(path,"serve");

Session tfSession = modelBundle.session();

Operation operationPredict = modelBundle.graph().operation("output/predictions");

Output output = new Output(operationPredict,0);

Tensor keep_prob = Tensor.create(Float.parseFloat("1.0"));

“path”是你模型保存的路径

"output/predictions"和python中的命令相对应,一定得一样,不要乱命名,例如output_y,绝对结果出错

下一步对于输入“很幸运遇见你”,python获得的word_to_index文件把输入转换成相对应的位置标签a,

Tensor input_x = Tensor.create(a);
Tensor out = tfSession.runner().feed("input_x", input_x).feed("dropout_keep_prob",keep_prob).fetch(output).run().get(0);

转成输入的tensor,a是二维向量;"dropout_keep_prob"与之前的相对应,不要乱写!!!,keep_prob预测的时候就设置成1吧,训练的时候可以随机关闭一半左右,但测试的时候你需要全用的。

long[] temp = new long[1];
out.copyTo(temp);
short reskey = (short) temp[0];

获取对应的分类坐标,你训练的时候会获得每个类别对应的坐标,然后根据上面获得的reskey去获得相应的类别就ok了!

 

 

总结以上主要有几个点:

   1、 不要python生成pb时的参数命名和java调用的时候不一致

   2、python和java的tensorflow版本必须一致,1.10和1.12都会报错

   3、输入转换成坐标向量和类别坐标这两个map中的对应顺序不要错了

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值