1. 经测试,Java 目前似乎只能加载 Python 端直接保存(SavedModel 格式)的模型,训练过程中保存的断点(checkpoint)权重无法直接加载。模型文件类型为 .pb(saved_model.pb)。
2. Java 加载模型的代码示例:
/**
 * Demo entry point: loads the SavedModel exported from Python (tag "serve") and runs
 * {@link #getResult} on a two-row batch in which both rows hold the same 8 feature values.
 *
 * @param args optional; {@code args[0]} overrides the model directory, otherwise the
 *     author's original hard-coded path is used
 * @throws Exception if the model cannot be loaded or inference fails
 */
public static void main(String[] args) throws Exception {
    String modelDir = args.length > 0 ? args[0] : "E:\\code\\python\\tensorflow\\filenew3\\";
    float[] sample = {
        23.8689218f, 18.37901895f, 0.22903802f, 0.32341023f,
        0.87688538f, 0.67953317f, 0.89015956f, 4.21707508f
    };
    // Two identical rows; clone so the rows do not alias the same array.
    float[][] b = {sample.clone(), sample.clone()};
    // The bundle owns a native graph and session; close it once inference is done.
    try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) {
        getResult(bundle, b);
    }
}
/**
 * Runs one inference on a loaded SavedModel and returns the argmax of the first output row.
 *
 * <p>Feeds {@code arr} to {@code serving_default_dense_input:0} and fetches
 * {@code StatefulPartitionedCall:0}. The output is expected to be a 2-D float tensor, presumably
 * [batch, classes] — TODO confirm against the model's "serving_default" signature. The output
 * buffer is sized from the tensor's own shape, so any batch size works.
 *
 * <p>Side effects: prints the argmax, then the start time, end time and elapsed time in
 * milliseconds to stdout.
 *
 * @param bundle a loaded SavedModel bundle; it is not closed by this method
 * @param arr input batch, one row per sample; must contain at least one row
 * @return index of the largest score in the first output row (first index wins on ties)
 * @throws IllegalArgumentException if {@code arr} is null or empty
 * @throws IllegalStateException if the fetched output is not 2-D
 */
public static int getResult(SavedModelBundle bundle, float[][] arr) {
    if (arr == null || arr.length == 0) {
        throw new IllegalArgumentException("arr must contain at least one row");
    }
    long from = System.currentTimeMillis();
    float[][] resultValues;
    // try-with-resources releases the native tensors even if run()/copyTo() throws.
    try (Tensor<?> input = Tensor.create(arr)) {
        List<Tensor<?>> outputs = bundle.session().runner()
                .feed("serving_default_dense_input:0", input)
                .fetch("StatefulPartitionedCall:0")
                .run();
        // Exactly one fetch was requested, so the list holds a single tensor.
        try (Tensor<?> output = outputs.get(0)) {
            long[] shape = output.shape();
            if (shape.length != 2) {
                throw new IllegalStateException(
                        "Expected a 2-D output tensor but got rank " + shape.length);
            }
            // copyTo needs a destination whose shape matches the tensor exactly.
            resultValues = output.copyTo(new float[(int) shape[0]][(int) shape[1]]);
        }
    }
    int max = argMax(resultValues[0]);
    System.out.println(max);
    long to = System.currentTimeMillis();
    System.out.println(from);
    System.out.println(to);
    System.out.println(to - from);
    return max;
}

/** Returns the index of the largest value in {@code row}; the first index wins on ties. */
private static int argMax(float[] row) {
    int best = 0;
    for (int i = 1; i < row.length; i++) {
        if (row[i] > row[best]) {
            best = i;
        }
    }
    return best;
}
3. 如果不清楚模型输入、输出张量的名称,可以用如下方法查看:
// Load the model only to inspect its "serving_default" signature; the bundle owns native
// resources, so close it afterwards.
try (SavedModelBundle bundle = SavedModelBundle.load("E:\\code\\python\\tensorflow\\filenew3\\", "serve")) {
    final MetaGraphDef metaGraphDef;
    try {
        metaGraphDef = MetaGraphDef.parseFrom(bundle.metaGraphDef());
    } catch (InvalidProtocolBufferException e) {
        // Continuing with a null MetaGraphDef would only fail later with an NPE; fail here
        // and keep the cause.
        throw new IllegalStateException("Could not parse the MetaGraphDef of the SavedModel", e);
    }
    // Print the tensor names to use with Session.Runner#feed (inputs) and #fetch (outputs).
    final SignatureDef signatureDef = metaGraphDef.getSignatureDefMap().get("serving_default");
    if (signatureDef == null) {
        throw new IllegalStateException("No \"serving_default\" signature; available: "
                + metaGraphDef.getSignatureDefMap().keySet());
    }
    // Print every input/output name, not just the first, so multi-input models are covered.
    for (TensorInfo inputTensorInfo : signatureDef.getInputsMap().values()) {
        System.out.println(inputTensorInfo.getName());
    }
    for (TensorInfo outputTensorInfo : signatureDef.getOutputsMap().values()) {
        System.out.println(outputTensorInfo.getName());
    }
}