Java TensorFlow SavedModelBundle 运行模型

首先需要有训练好的模型。

引入tensorflow的包到pom中

        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>1.15.0</version>
        </dependency>
// Session run configuration: pin the CPU device count and size the
// intra-op / inter-op thread pools used by the TensorFlow runtime.
ConfigProto config = ConfigProto.newBuilder()
                            .putDeviceCount("CPU", 8).setIntraOpParallelismThreads(50)
                            .setInterOpParallelismThreads(50)
                            .build();
// Model-run configuration (serialized and handed to the loader below).

ConfigProto 在 SavedModelBundle 初始化时加载。源码:

    // Loader builder method: stores the serialized ConfigProto bytes; they are
    // later handed to the native load() call when the SavedModelBundle is built.
    public Loader withConfigProto(byte[] configProto) {
      this.configProto = configProto;
      return this;
    }

创建 SavedModelBundle 实例,构造模型对象。

SavedModelBundle.loader(modelUri).withTags("serve").withConfigProto(config.toByteArray()).load();

调用的native方法。

  // JNI entry point: loads the SavedModel from exportDir with the given tags.
  // `config` is the serialized ConfigProto; `runOptions` serialized RunOptions.
  private static native SavedModelBundle load(
      String exportDir, String[] tags, byte[] config, byte[] runOptions);

下面就是跑模型的方法了。

/**
 * Runs the given SavedModelBundle session: feeds each entry of {@code feed} as an
 * input tensor keyed by its op name, fetches the ops named in {@code fetch}, and
 * copies the FIRST fetched output into a float matrix.
 *
 * <p>All tensors created or returned are closed in {@code finally} — tensor
 * memory lives off-heap and leaks (eventually OOM) if not released.
 *
 * @param savedModelBundle loaded model bundle; may be {@code null}
 * @param feed  map from input op name to its 2-D float data
 * @param fetch list of output op names to fetch (first one is returned)
 * @return the first fetched output as {@code float[rows][cols]}; an empty
 *     0x0 array when the bundle is null, inputs are empty, or the run fails
 */
public static float[][] run(SavedModelBundle savedModelBundle, Map<String, float[][]> feed, List<String> fetch) {
    float[][] result = new float[0][0];
    // Guard against a missing bundle and empty feed/fetch — running with no
    // fetches would make out.get(0) throw below.
    if (Objects.isNull(savedModelBundle)
            || feed == null || feed.isEmpty()
            || fetch == null || fetch.isEmpty()) {
        return result;
    }
    List<Tensor<?>> tensorList = Lists.newArrayList();
    try {
        Session.Runner runner = savedModelBundle.session().runner();
        // Feed each input: op name -> tensor built from the 2-D float data.
        for (Map.Entry<String, float[][]> entry : feed.entrySet()) {
            Tensor<?> inputTensor = Tensor.create(entry.getValue());
            tensorList.add(inputTensor);
            runner.feed(entry.getKey(), inputTensor);
        }
        // Register every requested output op.
        for (String fetchOpName : fetch) {
            runner.fetch(fetchOpName);
        }
        List<Tensor<?>> out = runner.run();
        tensorList.addAll(out); // output tensors must be closed too
        Tensor<?> tensor = out.get(0);
        // Size the result from the tensor's actual shape instead of the
        // previous hard-coded [batchSize][1], which broke for outputs wider
        // than one column. Assumes a rank-2 output — TODO confirm for all
        // models this helper serves.
        long[] shape = tensor.shape();
        int rows = shape.length > 0 ? (int) shape[0] : 0;
        int cols = shape.length > 1 ? (int) shape[1] : 1;
        result = new float[rows][cols];
        tensor.copyTo(result);
    } catch (Exception ex) {
        // Best-effort contract: log and fall through to the empty result.
        logger.error(ex.getMessage(), ex);
    } finally {
        // Always release native tensor memory, inputs and outputs alike.
        for (Tensor<?> tensorClose : tensorList) {
            tensorClose.close();
        }
    }
    return result;
}

Tensor源码:

/**
   * Create a Tensor of data type {@code dtype} from a Java object. Requires the parameter {@code T}
   * to match {@code type}, but this condition is not checked.
   *
   * @param obj the object supplying the tensor data.
   * @param dtype the data type of the tensor to create. It must be compatible with the run-time
   *     type of the object.
   * @return the new tensor
   */
  private static Tensor<?> create(Object obj, DataType dtype) {
    @SuppressWarnings("rawtypes")
    Tensor<?> t = new Tensor(dtype);
    // Derive the tensor shape from the (possibly nested) Java array object.
    t.shapeCopy = new long[numDimensions(obj, dtype)];
    fillShape(obj, 0, t.shapeCopy);
    long nativeHandle;
    if (t.dtype != DataType.STRING) {
      // Fixed-width element types: allocate a native buffer of the exact byte
      // size and copy the Java data into it.
      int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy);
      nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
      setValue(nativeHandle, obj);
    } else if (t.shapeCopy.length != 0) {
      // Non-scalar STRING tensor: elements are variable-length byte arrays.
      nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
    } else {
      // Scalar STRING tensor: obj is the raw byte[] payload itself.
      nativeHandle = allocateScalarBytes((byte[]) obj);
    }
    // NativeReference ties the native buffer's lifetime to this Tensor; the
    // memory is freed when the tensor is closed (or GC'd via the reference).
    t.nativeRef = new NativeReference(nativeHandle);
    return t;
  }

session run 的源码:

// Gathers native handles for all feeds, fetches, and targets, invokes the
// native Session.run, then wraps the returned output handles into Tensor
// objects. Returns both the outputs and (optionally) run metadata.
private Run runHelper(boolean wantMetadata) {
      long[] inputTensorHandles = new long[inputTensors.size()];
      long[] inputOpHandles = new long[inputs.size()];
      int[] inputOpIndices = new int[inputs.size()];
      long[] outputOpHandles = new long[outputs.size()];
      int[] outputOpIndices = new int[outputs.size()];
      long[] targetOpHandles = new long[targets.size()];
      // Filled in by the native call: one handle per fetched output tensor.
      long[] outputTensorHandles = new long[outputs.size()];

      // It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the
      // validity of the Graph and graphRef ensures that.
      int idx = 0;
      for (Tensor<?> t : inputTensors) {
        inputTensorHandles[idx++] = t.getNativeHandle();
      }
      idx = 0;
      for (Output<?> o : inputs) {
        inputOpHandles[idx] = o.getUnsafeNativeHandle();
        inputOpIndices[idx] = o.index();
        idx++;
      }
      idx = 0;
      for (Output<?> o : outputs) {
        outputOpHandles[idx] = o.getUnsafeNativeHandle();
        outputOpIndices[idx] = o.index();
        idx++;
      }
      idx = 0;
      for (GraphOperation op : targets) {
        targetOpHandles[idx++] = op.getUnsafeNativeHandle();
      }
      // Reference keeps the session alive for the duration of the native call.
      Reference runRef = new Reference();
      byte[] metadata = null;
      try {
        metadata =
            Session.run(
                nativeHandle,
                runOptions,
                inputTensorHandles,
                inputOpHandles,
                inputOpIndices,
                outputOpHandles,
                outputOpIndices,
                targetOpHandles,
                wantMetadata,
                outputTensorHandles);
      } finally {
        runRef.close();
      }
      // Wrap each native output handle in a Tensor; if wrapping any one of
      // them fails, close those already wrapped so no native memory leaks.
      List<Tensor<?>> outputs = new ArrayList<Tensor<?>>();
      for (long h : outputTensorHandles) {
        try {
          outputs.add(Tensor.fromHandle(h));
        } catch (Exception e) {
          for (Tensor<?> t : outputs) {
            t.close();
          }
          outputs.clear();
          throw e;
        }
      }
      Run ret = new Run();
      ret.outputs = outputs;
      ret.metadata = metadata;
      return ret;
    }

 // JNI entry point backing runHelper: executes the session with the serialized
 // RunOptions, writes fetched tensor handles into outputTensorHandles, and
 // returns serialized RunMetadata bytes (or null when not requested).
 private static native byte[] run(
      long handle,
      byte[] runOptions,
      long[] inputTensorHandles,
      long[] inputOpHandles,
      int[] inputOpIndices,
      long[] outputOpHandles,
      int[] outputOpIndices,
      long[] targetOpHandles,
      boolean wantRunMetadata,
      long[] outputTensorHandles);
}

 

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值