TensorFlow 2 saveModel模型JAVA 加载及调用

名词释义

1、 saveModel模型:SavedModel 是更全面的保存格式,它可以保存模型架构、权重和调用函数的跟踪 Tensorflow 子计算图。保存模型和模型的层时,SavedModel 格式会存储类名称、调用函数、损失和权重(如果已实现,还包括配置)。调用函数会定义模型/层的计算图。
2、张量(Tensor): 张量是具有统一类型(称为 dtype)的多维数组, :无法更新,只能创建新的张量.其定义格式为:Tensor([1 0], shape=(2,), dtype=int64),第一个参数为多维数组,参数二为数组维度信息,参数三为数组类型。
3、计算图(Graph):计算图是包含一组 tf.Operation 对象(表示计算单元)和 tf.Tensor 对象(表示在运算之间流动的数据单元)的数据结构

saveModel模型加载

TensorFlow 模型的加载和模型生成的版本最好吻合, 否则容易出现各种导致加载失败或是运行失败的bug.

Maven包引入

TensorFlow 2.x 版本的jar包不再沿用以前的包名 ,其版本对应关系如下:
TensorFlow和jar包版本对应关系
maven pom 引用

	  <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-api</artifactId>
            <version>0.5.0</version>
        </dependency>
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-api</artifactId>
            <version>0.5.0</version>
            <classifier>linux-x86_64</classifier>
        </dependency>
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-api</artifactId>
            <version>0.5.0</version>
            <classifier>windows-x86_64</classifier>
        </dependency>

模型加载

  		SavedModelBundle savedModelBundle = SavedModelBundle.load("C:\\Users\\Admin\\Desktop\\xxx_savedmodel","serve");
        Map<String, SignatureDef> signatureDefMap = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef().toByteArray()).getSignatureDefMap();
        /**
         * 获取基本定义信息
         */
        SignatureDef modelSig = signatureDefMap.get("serving_default");
        int numInputs = modelSig.getInputsCount();
        String inputTensorName = modelSig.getInputsMap().get("input_1").getName();
        String outputTensorName = modelSig.getOutputsMap().get("dense_3").getName();
        System.out.println(String.format("numInputs: %d, inputTensorName: %s, outputTensor: %s", numInputs, inputTensorName, outputTensorName));

输出如下

numInputs: 1, inputTensorName: serving_default_input_1:0, outputTensor: StatefulPartitionedCall:0

参数说明:
1.serving_default: 来源于模型基本信息signature_def[‘serving_default’]
2.input_1: 来源于模型输入字段名inputs[‘input_1’]
3.dense_3: 来源于模型输出字段名outputs[‘dense_3’]

模型的基本信息可以通过saved_model_cli 查看;

调用模型

调用参数组装

根据实际入参结构组装参数

 		LongNdArray matrix3d = NdArrays.ofLongs(Shape.of(1,100,9));
        //假定参数
        matrix3d.elements(0).forEach(matrix -> {
            matrix
                    .set(NdArrays.vectorOf(1l, 2l,3l,2l,3l,2l,3l,2l,3l), 0)
                    .set(NdArrays.vectorOf(3l, 4l,6l,2l,3l,2l,3l,2l,3l), 1);
        });
        TInt64 rank3Tensor = TInt64.tensorOf(matrix3d);

调用模型

		Graph graph = savedModelBundle.graph();
        try(Session session = savedModelBundle.session()){
            Result run = session.runner()
                    .feed(inputTensorName,rank3Tensor)
                    .fetch(outputTensorName)
                    .run();
            Tensor out = run.get(0);
            Shape shape = out.shape();

            System.out.println(shape);

输出结果如下

[1, 3]

修改模型入参类型

在未指定模型输入参数类型的时,模型默认入参为INT64, 此时可能需要使用python 修改模型入参类型,并重新导出 , python 代码如下 :

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models

print(tf.__version__)

load_path = 'C:\\Users\\Admin\\Desktop\\tf_model_savedmodel_new'
save_path = 'C:\\Users\\Admin\\Desktop\\tf_model_savedmodel_new_re'
//加载模型
model = tf.saved_model.load(load_path)
//获取对应模型签名
signatureDefault = model.signatures['serving_default']
print(signatureDefault.inputs[0])
//设置模型入参类型
input_layer = layers.Input(shape=(None,100,9), dtype=tf.float32)
dense_layer = layers.Dense(units=3, activation='relu')(input_layer)
model = models.Model(inputs=[input_layer], outputs=[dense_layer])
model.save(save_path)
  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值