直接贴代码 # 将模型保存为可用于线上服务的文件(一个.pb文件,一个variables文件夹) # print('Exporting trained model to', save_dir) builder = tf.saved_model.builder.SavedModelBuilder(save_dir) # 服务器专用代码 classification_signature = ( tf.saved_model.signature_def_utils.build_signature_def( inputs={ # "image" "input_x": tf.saved_model.utils.build_tensor_info(rnn.input_x), "dropout_keep_prob": tf.saved_model.utils.build_tensor_info(rnn.dropout_keep_prob) }, outputs={ # "classify" "output": tf.saved_model.utils.build_tensor_info(rnn.predictions) # "classification_outputs_scores": # tf.saved_model.utils.build_tensor_info(model.logits) }, # Prediction method name used in a SignatureDef. # PREDICT_METHOD_NAME = "tensorflow/serving/predict" method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)) legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op') builder.add_meta_graph_and_variables( # saved_model.tag_constants.SERVING = "saved_model.tag_constants.SERVING" sess, [tf.saved_model.tag_constants.SERVING], # 保存模型的方法名,与客户端的request.model_spec.signature_name对应 signature_def_map={ # "predict_image" "classification": classification_signature}, legacy_init_op=legacy_init_op) builder.save()
1、"input_x","output"这个千万不要乱写,因为你java调用的时候必须前后这个命名一致,否则会导致java调用模型预测结果与 python模型结果存在很大的差别
2、rnn是你的模型,rnn = TextRnn(config)
3、rnn.dropout_keep_prob是你的drop的命名方式,一定得和后续的一致
4、rnn.dropout_keep_prob与3一样
5、python跑模型的tensorflow的版本必须和java调用的版本一样!!!
模型格式如下:
现在模型准备好了就开始java调用了:
SavedModelBundle modelBundle = SavedModelBundle.load(path,"serve"); Session tfSession = modelBundle.session(); Operation operationPredict = modelBundle.graph().operation("output/predictions"); Output output = new Output(operationPredict,0); Tensor keep_prob = Tensor.create(Float.parseFloat("1.0"));
“path”是你模型保存的路径
"output/predictions"和python中的命令相对应,一定得一样,不要乱命名,例如output_y,绝对结果出错
下一步对于输入“很幸运遇见你”,python获得的word_to_index文件把输入转换成相对应的位置标签a,
Tensor input_x = Tensor.create(a); Tensor out = tfSession.runner().feed("input_x", input_x).feed("dropout_keep_prob",keep_prob).fetch(output).run().get(0);
转成输入的tensor,a是二维向量;"dropout_keep_prob"与之前的相对应,不要乱写!!!,keep_prob预测的时候就设置成1吧,训练的时候可以随机关闭一半左右,但测试的时候你需要全用的。
long[] temp = new long[1]; out.copyTo(temp); short reskey = (short) temp[0];
获取对应的分类坐标,你训练的时候会获得每个类别对应的坐标,然后根据上面获得的reskey去获得相应的类别就ok了!
总结以上主要有几个点:
1、 不要python生成pb时的参数命名和java调用的时候不一致
2、python和java的tensorflow版本必须一致,1.10和1.12都会报错
3、输入转换成坐标向量和类别坐标这两个map中的对应顺序不要错了