名词释义
1、 saveModel模型:SavedModel 是更全面的保存格式,它可以保存模型架构、权重和调用函数的跟踪 Tensorflow 子计算图。保存模型和模型的层时,SavedModel 格式会存储类名称、调用函数、损失和权重(如果已实现,还包括配置)。调用函数会定义模型/层的计算图。
2、张量(Tensor): 张量是具有统一类型(称为 dtype)的多维数组, :无法更新,只能创建新的张量.其定义格式为:Tensor([1 0], shape=(2,), dtype=int64),第一个参数为多维数组,参数二为数组维度信息,参数三为数组类型。
3、计算图(Graph):计算图是包含一组 tf.Operation 对象(表示计算单元)和 tf.Tensor 对象(表示在运算之间流动的数据单元)的数据结构
saveModel模型加载
TensorFlow 模型的加载和模型生成的版本最好吻合, 否则容易出现各种导致加载失败或是运行失败的bug.
Maven包引入
TensorFlow 2.x 版本的jar包不再沿用以前的包名 ,其版本对应关系如下:
maven pom 引用
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-api</artifactId>
<version>0.5.0</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-api</artifactId>
<version>0.5.0</version>
<classifier>linux-x86_64</classifier>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-api</artifactId>
<version>0.5.0</version>
<classifier>windows-x86_64</classifier>
</dependency>
模型加载
SavedModelBundle savedModelBundle = SavedModelBundle.load("C:\\Users\\Admin\\Desktop\\xxx_savedmodel","serve");
Map<String, SignatureDef> signatureDefMap = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef().toByteArray()).getSignatureDefMap();
/**
* 获取基本定义信息
*/
SignatureDef modelSig = signatureDefMap.get("serving_default");
int numInputs = modelSig.getInputsCount();
String inputTensorName = modelSig.getInputsMap().get("input_1").getName();
String outputTensorName = modelSig.getOutputsMap().get("dense_3").getName();
System.out.println(String.format("numInputs: %d, inputTensorName: %s, outputTensor: %s", numInputs, inputTensorName, outputTensorName));
输出如下
numInputs: 1, inputTensorName: serving_default_input_1:0, outputTensor: StatefulPartitionedCall:0
参数说明:
1.serving_default: 来源于模型基本信息signature_def[‘serving_default’]
2.input_1: 来源于模型输入字段名inputs[‘input_1’]
3.dense_3: 来源于模型输出字段名outputs[‘dense_3’]
模型的基本信息可以通过saved_model_cli 查看;
调用模型
调用参数组装
根据实际入参结构组装参数
LongNdArray matrix3d = NdArrays.ofLongs(Shape.of(1,100,9));
//假定参数
matrix3d.elements(0).forEach(matrix -> {
matrix
.set(NdArrays.vectorOf(1l, 2l,3l,2l,3l,2l,3l,2l,3l), 0)
.set(NdArrays.vectorOf(3l, 4l,6l,2l,3l,2l,3l,2l,3l), 1);
});
TInt64 rank3Tensor = TInt64.tensorOf(matrix3d);
调用模型
Graph graph = savedModelBundle.graph();
try(Session session = savedModelBundle.session()){
Result run = session.runner()
.feed(inputTensorName,rank3Tensor)
.fetch(outputTensorName)
.run();
Tensor out = run.get(0);
Shape shape = out.shape();
System.out.println(shape);
输出结果如下
[1, 3]
修改模型入参类型
在未指定模型输入参数类型的时,模型默认入参为INT64, 此时可能需要使用python 修改模型入参类型,并重新导出 , python 代码如下 :
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
print(tf.__version__)
load_path = 'C:\\Users\\Admin\\Desktop\\tf_model_savedmodel_new'
save_path = 'C:\\Users\\Admin\\Desktop\\tf_model_savedmodel_new_re'
//加载模型
model = tf.saved_model.load(load_path)
//获取对应模型签名
signatureDefault = model.signatures['serving_default']
print(signatureDefault.inputs[0])
//设置模型入参类型
input_layer = layers.Input(shape=(None,100,9), dtype=tf.float32)
dense_layer = layers.Dense(units=3, activation='relu')(input_layer)
model = models.Model(inputs=[input_layer], outputs=[dense_layer])
model.save(save_path)