java 加载tensorflow model

1.经测试,java好像目前只能加载python直接save的model,断点权重model不行。model类型为pd.

2.java加载model代码样例

public static void main(String[] args) throws Exception {

        SavedModelBundle bundle = SavedModelBundle.load("E:\\code\\python\\tensorflow\\filenew3\\", "serve");

        float[] a = new float[]{23.8689218f, 18.37901895f,  0.22903802f,  0.32341023f,  0.87688538f,  0.67953317f, 0.89015956f,  4.21707508f};
        float[][] b = new float[2][8];
        b[0][0] = 23.8689218f;
        b[0][1] = 18.37901895f;
        b[0][2] = 0.22903802f;
        b[0][3] = 0.32341023f;
        b[0][4] = 0.87688538f;
        b[0][5] = 0.67953317f;
        b[0][6] = 0.89015956f;
        b[0][7] = 4.21707508f;
        b[1][0] = 23.8689218f;
        b[1][1] = 18.37901895f;
        b[1][2] = 0.22903802f;
        b[1][3] = 0.32341023f;
        b[1][4] = 0.87688538f;
        b[1][5] = 0.67953317f;
        b[1][6] = 0.89015956f;
        b[1][7] = 4.21707508f;

        getResult(bundle, b);
}

 public  static int  getResult(SavedModelBundle bundle, float[][] arr){
        long from = System.currentTimeMillis();
        Tensor  tensor= Tensor.create(arr);
        List<Tensor<?>> result= bundle.session().runner().feed("serving_default_dense_input:0",tensor).fetch("StatefulPartitionedCall:0").run();
        Tensor t = result.get(0);
        float[][] resultValues = (float[][])t.copyTo(new float[2][17]);
        int Max = 0;
        for (int i = 1; i < 17; i++) {
            if (resultValues[0][i] > resultValues[0][Max]) {
                Max = i;
            }
        }

        System.out.println(Max);
        tensor.close();
        t.close();
        long to = System.currentTimeMillis();
        System.out.println(from);
        System.out.println(to);
        System.out.println(to - from);
        return Max;
    }

3. 输入输出名称不清楚,可以用如下方法

SavedModelBundle bundle = SavedModelBundle.load("E:\\code\\python\\tensorflow\\filenew3\\", "serve");      
  MetaGraphDef metaGraphDef = null;
        try {
            metaGraphDef = MetaGraphDef.parseFrom(bundle.metaGraphDef());
        } catch (InvalidProtocolBufferException e) {
            e.printStackTrace();
        }
        //用如下方法获取输入输出参数名
        final SignatureDef signatureDef = metaGraphDef.getSignatureDefMap().get("serving_default");
        Map map1 = signatureDef.getInputsMap();

        final TensorInfo inputTensorInfo = signatureDef.getInputsMap()
                .values()
                .stream()
                .filter(Objects::nonNull)
                .findFirst()
                .orElseThrow(Exception::new);
        System.out.println(inputTensorInfo.getName());
        Map map2 = signatureDef.getOutputsMap();
        final TensorInfo outputTensorInfo = signatureDef.getOutputsMap()
                .values()
                .stream()
                .filter(Objects::nonNull)
                .findFirst()
                .orElseThrow(Exception::new);
        System.out.println(outputTensorInfo.getName());

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值