tensenflow java_tensorflow serving java

系列

背景介绍

这篇文章是tensorflow serving java api使用的参考案例,基本上把TFS的核心API的用法都介绍清楚。案例主要分为三部分:

动态更新模型:用于在TFS处于runtime时候动态加载模型。

获取模型状态:用于获取加载的模型的基本信息。

在线模型预测:进行在线预测,分类等操作,着重介绍在线预测。

因为模型的预测需要参考模型内部变量,所以可以先行通过TFS的REST接口获取TF模型的元数据然后才能构建TFS的RPC请求对象。

TFS 使用入门

模型源数据获取

curl http://host:port/v1/models/${MODEL_NAME}[/versions/${MODEL_VERSION}]/metadata

说明:

public static void getModelStatus() {

// 1、设置访问的RPC协议的host和port

ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();

// 2、构建PredictionServiceBlockingStub对象

PredictionServiceGrpc.PredictionServiceBlockingStub predictionServiceBlockingStub =

PredictionServiceGrpc.newBlockingStub(channel);

// 3、设置待获取的模型

Model.ModelSpec modelSpec = Model.ModelSpec.newBuilder()

.setName("wdl_model").build();

// 4、构建获取元数据的请求

GetModelMetadata.GetModelMetadataRequest modelMetadataRequest =

GetModelMetadata.GetModelMetadataRequest.newBuilder()

.setModelSpec(modelSpec)

.addAllMetadataField(Arrays.asList("signature_def"))

.build();

// 5、获取元数据

GetModelMetadata.GetModelMetadataResponse getModelMetadataResponse =

predictionServiceBlockingStub.getModelMetadata(modelMetadataRequest);

channel.shutdownNow();

}

说明:

Model.ModelSpec.newBuilder绑定需要访问的模型的名字。

GetModelMetadataRequest中addAllMetadataField绑定curl命令返回的metadata当中的signature_def字段。

动态更新模型

public static void addNewModel() {

// 1、构建动态更新模型1

ModelServerConfigOuterClass.ModelConfig modelConfig1 =

ModelServerConfigOuterClass.ModelConfig.newBuilder()

.setBasePath("/models/new_model")

.setName("new_model")

.setModelType(ModelServerConfigOuterClass.ModelType.TENSORFLOW)

.build();

// 2、构建动态更新模型2

ModelServerConfigOuterClass.ModelConfig modelConfig2 =

ModelServerConfigOuterClass.ModelConfig.newBuilder()

.setBasePath("/models/wdl_model")

.setName("wdl_model")

.setModelType(ModelServerConfigOuterClass.ModelType.TENSORFLOW)

.build();

// 3、合并动态更新模型到ModelConfigList对象中

ModelServerConfigOuterClass.ModelConfigList modelConfigList =

ModelServerConfigOuterClass.ModelConfigList.newBuilder()

.addConfig(modelConfig1)

.addConfig(modelConfig2)

.build();

// 4、添加到ModelConfigList到ModelServerConfig对象当中

ModelServerConfigOuterClass.ModelServerConfig modelServerConfig =

ModelServerConfigOuterClass.ModelServerConfig.newBuilder()

.setModelConfigList(modelConfigList)

.build();

// 5、构建ReloadConfigRequest并绑定ModelServerConfig对象。

ModelManagement.ReloadConfigRequest reloadConfigRequest =

ModelManagement.ReloadConfigRequest.newBuilder()

.setConfig(modelServerConfig)

.build();

// 6、构建modelServiceBlockingStub访问句柄

ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();

ModelServiceGrpc.ModelServiceBlockingStub modelServiceBlockingStub =

ModelServiceGrpc.newBlockingStub(channel);

ModelManagement.ReloadConfigResponse reloadConfigResponse =

modelServiceBlockingStub.handleReloadConfigRequest(reloadConfigRequest);

System.out.println(reloadConfigResponse.getStatus().getErrorMessage());

channel.shutdownNow();

}

说明:

动态更新模型是一个全量的模型加载,在发布A模型后想动态发布B模型需要同时传递模型A和B的信息。

再次强调,需要全量更新,全量更新,全量更新!!!

在线模型预测

public static void doPredict() throws Exception {

// 1、构建feature

Map featureMap = new HashMap<>();

featureMap.put("match_type", feature(""));

featureMap.put("position", feature(0.0f));

featureMap.put("brand_prefer_1d", feature(0.0f));

featureMap.put("brand_prefer_1m", feature(0.0f));

featureMap.put("brand_prefer_1w", feature(0.0f));

featureMap.put("brand_prefer_2w", feature(0.0f));

featureMap.put("browse_norm_score_1d", feature(0.0f));

featureMap.put("browse_norm_score_1w", feature(0.0f));

featureMap.put("browse_norm_score_2w", feature(0.0f));

featureMap.put("buy_norm_score_1d", feature(0.0f));

featureMap.put("buy_norm_score_1w", feature(0.0f));

featureMap.put("buy_norm_score_2w", feature(0.0f));

featureMap.put("cate1_prefer_1d", feature(0.0f));

featureMap.put("cate1_prefer_2d", feature(0.0f));

featureMap.put("cate1_prefer_1m", feature(0.0f));

featureMap.put("cate1_prefer_1w", feature(0.0f));

featureMap.put("cate1_prefer_2w", feature(0.0f));

featureMap.put("cate2_prefer_1d", feature(0.0f));

featureMap.put("cate2_prefer_1m", feature(0.0f));

featureMap.put("cate2_prefer_1w", feature(0.0f));

featureMap.put("cate2_prefer_2w", feature(0.0f));

featureMap.put("cid_prefer_1d", feature(0.0f));

featureMap.put("cid_prefer_1m", feature(0.0f));

featureMap.put("cid_prefer_1w", feature(0.0f));

featureMap.put("cid_prefer_2w", feature(0.0f));

featureMap.put("user_buy_rate_1d", feature(0.0f));

featureMap.put("user_buy_rate_2w", feature(0.0f));

featureMap.put("user_click_rate_1d", feature(0.0f));

featureMap.put("user_click_rate_1w", feature(0.0f));

Features features = Features.newBuilder().putAllFeature(featureMap).build();

Example example = Example.newBuilder().setFeatures(features).build();

// 2、构建Predict请求

Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();

// 3、构建模型请求维度ModelSpec,绑定模型名和预测的签名

Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();

modelSpecBuilder.setName("wdl_model");

modelSpecBuilder.setSignatureName("predict");

predictRequestBuilder.setModelSpec(modelSpecBuilder);

// 4、构建预测请求的维度信息DIM对象

TensorShapeProto.Dim dim = TensorShapeProto.Dim.newBuilder().setSize(300).build();

TensorShapeProto shapeProto = TensorShapeProto.newBuilder().addDim(dim).build();

TensorProto.Builder tensor = TensorProto.newBuilder();

tensor.setTensorShape(shapeProto);

tensor.setDtype(DataType.DT_STRING);

// 5、批量绑定预测请求的数据

for (int i=0; i<300; i++) {

tensor.addStringVal(example.toByteString());

}

predictRequestBuilder.putInputs("examples", tensor.build());

// 6、构建PredictionServiceBlockingStub对象准备预测

ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();

PredictionServiceGrpc.PredictionServiceBlockingStub predictionServiceBlockingStub =

PredictionServiceGrpc.newBlockingStub(channel);

// 7、执行预测

Predict.PredictResponse predictResponse =

predictionServiceBlockingStub.predict(predictRequestBuilder.build());

// 8、解析请求结果

List floatList = predictResponse

.getOutputsOrThrow("probabilities")

.getFloatValList();

}

说明:

TFS的RPC请求过程中设置的参数需要考虑TF模型的数据结构。

TFS的RPC请求有同步和异步两种方式,上述只展示同步方式。

TF模型结构

{

"model_spec": {

"name": "wdl_model",

"signature_name": "",

"version": "4"

},

"metadata": {

"signature_def": {

"signature_def": {

"predict": {

"inputs": {

"examples": {

"dtype": "DT_STRING",

"tensor_shape": {

"dim": [{

"size": "-1",

"name": ""

}],

"unknown_rank": false

},

"name": "input_example_tensor:0"

}

},

"outputs": {

"logistic": {

"dtype": "DT_FLOAT",

"tensor_shape": {

"dim": [{

"size": "-1",

"name": ""

},

{

"size": "1",

"name": ""

}

],

"unknown_rank": false

},

"name": "head/predictions/logistic:0"

},

"class_ids": {

"dtype": "DT_INT64",

"tensor_shape": {

"dim": [{

"size": "-1",

"name": ""

},

{

"size": "1",

"name": ""

}

],

"unknown_rank": false

},

"name": "head/predictions/ExpandDims:0"

},

"probabilities": {

"dtype": "DT_FLOAT",

"tensor_shape": {

"dim": [{

"size": "-1",

"name": ""

},

{

"size": "2",

"name": ""

}

],

"unknown_rank": false

},

"name": "head/predictions/probabilities:0"

},

"classes": {

"dtype": "DT_STRING",

"tensor_shape": {

"dim": [{

"size": "-1",

"name": ""

},

{

"size": "1",

"name": ""

}

],

"unknown_rank": false

},

"name": "head/predictions/str_classes:0"

},

"logits": {

"dtype": "DT_FLOAT",

"tensor_shape": {

"dim": [{

"size": "-1",

"name": ""

},

{

"size": "1",

"name": ""

}

],

"unknown_rank": false

},

"name": "add:0"

}

},

"method_name": "tensorflow/serving/predict"

},

"classification": {

"inputs": {

"inputs": {

"dtype": "DT_STRING",

"tensor_shape": {

"dim": [{

"size": "-1",

"name": ""

}],

"unknown_rank": false

},

"name": "input_example_tensor:0"

}

},

"outputs": {

"classes": {

"dtype": "DT_STRING",

"tensor_shape": {

"dim": [{

"size": "-1",

"name": ""

},

{

"size": "2",

"name": ""

}

],

"unknown_rank": false

},

"name": "head/Tile:0"

},

"scores": {

"dtype": "DT_FLOAT",

"tensor_shape": {

"dim": [{

"size": "-1",

"name": ""

},

{

"size": "2",

"name": ""

}

],

"unknown_rank": false

},

"name": "head/predictions/probabilities:0"

}

},

"method_name": "tensorflow/serving/classify"

},

"regression": {

"inputs": {

"inputs": {

"dtype": "DT_STRING",

"tensor_shape": {

"dim": [{

"size": "-1",

"name": ""

}],

"unknown_rank": false

},

"name": "input_example_tensor:0"

}

},

"outputs": {

"outputs": {

"dtype": "DT_FLOAT",

"tensor_shape": {

"dim": [{

"size": "-1",

"name": ""

},

{

"size": "1",

"name": ""

}

],

"unknown_rank": false

},

"name": "head/predictions/logistic:0"

}

},

"method_name": "tensorflow/serving/regress"

},

"serving_default": {

"inputs": {

"inputs": {

"dtype": "DT_STRING",

"tensor_shape": {

"dim": [{

"size": "-1",

"name": ""

}],

"unknown_rank": false

},

"name": "input_example_tensor:0"

}

},

"outputs": {

"classes": {

"dtype": "DT_STRING",

"tensor_shape": {

"dim": [{

"size": "-1",

"name": ""

},

{

"size": "2",

"name": ""

}

],

"unknown_rank": false

},

"name": "head/Tile:0"

},

"scores": {

"dtype": "DT_FLOAT",

"tensor_shape": {

"dim": [{

"size": "-1",

"name": ""

},

{

"size": "2",

"name": ""

}

],

"unknown_rank": false

},

"name": "head/predictions/probabilities:0"

}

},

"method_name": "tensorflow/serving/classify"

}

}

}

}

}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值