03_Tensor.java 和 Tensors.java 源码

 1、调用分析

// 1. Initialize tensor support: the static initializer below calls the TensorFlow Java API
//    (TensorFlow.init()) to set up the TensorFlow environment.
static {
    TensorFlow.init();   // per the original note: initializes the TensorFlow environment
}

 创建 Tensor

// Create a Tensor with the following call.
Tensor<Float> imageTensor = Tensors.create(rgbFloat);  // initialize the Tensor variable from the image data


//From Tensors.java 

/**
   * Creates a rank-4 tensor of {@code float} elements.
   *
   * <p>This is a typed convenience wrapper: it forwards {@code data} to
   * {@code Tensor.create(Object, Class)} with {@code Float.class} as the element type.
   *
   * @param data An array containing the values to put into the new tensor. The dimensions of the
   *     new tensor will match those of the array.
   * @return the {@code Tensor<Float>} returned by {@code Tensor.create(data, Float.class)}.
   */
  public static Tensor<Float> create(float[][][][] data) {
    return Tensor.create(data, Float.class);    // implemented in Tensor.java
  }


//From Tensor.java
/**
   * Builds a {@code Tensor<T>} out of a plain Java value.
   *
   * <p>Only a restricted family of Java objects maps onto the TensorFlow type system: a boxed
   * primitive (float, double, int, long, boolean, byte) or a multi-dimensional array of one of
   * those primitives. The {@code type} argument tells this method which TensorFlow element type
   * {@code obj} should be read as. For example:
   *
   * <pre>{@code
   * // Valid: A 64-bit integer scalar.
   * Tensor<Long> s = Tensor.create(42L, Long.class);
   *
   * // Valid: A 3x2 matrix of floats.
   * float[][] matrix = new float[3][2];
   * Tensor<Float> m = Tensor.create(matrix, Float.class);
   *
   * // Invalid: Will throw an IllegalArgumentException as an arbitrary Object
   * // does not fit into the TensorFlow type system.
   * Tensor<?> o = Tensor.create(new Object());
   *
   * // Invalid: Will throw an IllegalArgumentException since there are
   * // a differing number of elements in each row of this 2-D array.
   * int[][] twoD = new int[2][];
   * twoD[0] = new int[1];
   * twoD[1] = new int[2];
   * Tensor<Integer> x = Tensor.create(twoD, Integer.class);
   * }</pre>
   *
   * <p>A {@link String}-typed tensor is a multi-dimensional array of arbitrary byte sequences, so
   * it is initialized from arrays whose leaf elements are {@code byte[]}. For example:
   *
   * <pre>{@code
   * // Valid: A String tensor.
   * Tensor<String> s = Tensor.create(new byte[]{1, 2, 3}, String.class);
   *
   * // Java Strings will need to be encoded into a byte-sequence.
   * String mystring = "foo";
   * Tensor<String> s = Tensor.create(mystring.getBytes("UTF-8"), String.class);
   *
   * // Valid: Matrix of String tensors.
   * // Each element might have a different length.
   * byte[][][] matrix = new byte[2][2][];
   * matrix[0][0] = "this".getBytes("UTF-8");
   * matrix[0][1] = "is".getBytes("UTF-8");
   * matrix[1][0] = "a".getBytes("UTF-8");
   * matrix[1][1] = "matrix".getBytes("UTF-8");
   * Tensor<String> m = Tensor.create(matrix, String.class);
   * }</pre>
   *
   * @param obj The value to turn into a {@code Tensor<T>}. The compiler cannot verify that it
   *     agrees with {@code T}; prefer the factories in {@link Tensors} when static type safety is
   *     wanted.
   * @param type The class object standing for the element type {@code T}.
   * @return a new tensor holding the contents of {@code obj}.
   * @throws IllegalArgumentException if {@code obj} is not compatible with the TensorFlow type
   *     system.
   */
  @SuppressWarnings("unchecked")
  public static <T> Tensor<T> create(Object obj, Class<T> type) {
    final DataType expected = DataType.fromClass(type);
    if (objectCompatWithType(obj, expected)) {
      // The cast is unchecked; the compatibility test above is the only guard that the
      // DataType-based factory yields a tensor matching T.
      return (Tensor<T>) create(obj, expected);
    }
    final String message =
        "DataType of object does not match T (expected "
            + expected
            + ", got "
            + dataTypeOf(obj)
            + ")";
    throw new IllegalArgumentException(message);
  }


//From Tensor.java
/**
   * Creates a Tensor of data type {@code dtype} from a Java object. The type parameter the caller
   * assigns to the result must correspond to {@code dtype}; this condition is not checked here.
   *
   * <p>The shape is derived from {@code obj} via {@code numDimensions} and {@code fillShape}.
   * Non-string tensors get a native buffer from {@code allocate} that is then filled by
   * {@code setValue}; string tensors are built by {@code allocateNonScalarBytes} (non-empty
   * shape) or {@code allocateScalarBytes} (empty shape, i.e. a scalar).
   *
   * @param obj the object supplying the tensor data.
   * @param dtype the data type of the tensor to create. It must be compatible with the run-time
   *     type of the object.
   * @return the new tensor
   */
  private static Tensor<?> create(Object obj, DataType dtype) {
    @SuppressWarnings("rawtypes")
    Tensor<?> t = new Tensor(dtype);        // construct the (still empty) Tensor object
    t.shapeCopy = new long[numDimensions(obj, dtype)];
    fillShape(obj, 0, t.shapeCopy);
    long nativeHandle;
    if (t.dtype != DataType.STRING) {
      // Widen to long BEFORE multiplying: the native allocate() takes a long byte count, and a
      // 32-bit product would silently wrap for tensors of 2 GiB or more.
      // NOTE(review): numElements is not visible here; if it accumulates in int it can still
      // overflow on its own -- confirm.
      long byteSize = (long) elemByteSize(t.dtype) * numElements(t.shapeCopy);
      nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);    // native call, see tensor_jni.cc
      setValue(nativeHandle, obj);
    } else if (t.shapeCopy.length != 0) {
      nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
    } else {
      nativeHandle = allocateScalarBytes((byte[]) obj);
    }
    t.nativeRef = new NativeReference(nativeHandle);
    return t;
  }


// Native method; implemented in tensor_jni.cc (Java_org_tensorflow_Tensor_allocate).
// Takes the data type code, the shape, and the buffer size in bytes; returns a native handle.
private static native long allocate(int dtype, long[] shape, long byteSize);

//tensor_jni.cc
// JNI entry point for org.tensorflow.Tensor.allocate(int, long[], long).
//
// Copies the Java shape array into an int64_t array, asks the TensorFlow C API
// (TF_AllocateTensor) for a tensor of `sizeInBytes` bytes, and returns the
// resulting TF_Tensor* as a jlong handle. Returns 0 with a Java exception
// pending when the shape cannot be read or the tensor cannot be allocated.
JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv* env,
                                                            jclass clazz,
                                                            jint dtype,
                                                            jlongArray shape,
                                                            jlong sizeInBytes) {
  int num_dims = static_cast<int>(env->GetArrayLength(shape));
  jlong* dims = nullptr;
  if (num_dims > 0) {
    // The is-copy flag is not needed, so pass nullptr for it.
    dims = env->GetLongArrayElements(shape, nullptr);
    if (dims == nullptr) {
      // GetLongArrayElements returns NULL on failure; reading dims[i] below
      // would then dereference a null pointer. Make sure an exception is
      // pending and bail out.
      if (!env->ExceptionCheck()) {
        throwException(env, kNullPointerException,
                       "unable to access the shape of the Tensor");
      }
      return 0;
    }
  }
  static_assert(sizeof(jlong) == sizeof(int64_t),
                "Java long is not compatible with the TensorFlow C API");
  // On some platforms "jlong" is a "long" while "int64_t" is a "long long".
  //
  // Thus, static_cast<int64_t*>(dims) will trigger a compiler error:
  // static_cast from 'jlong *' (aka 'long *') to 'int64_t *' (aka 'long long
  // *') is not allowed
  //
  // Since this array is typically very small, use the guaranteed safe scheme of
  // creating a copy.
  int64_t* dims_copy = new int64_t[num_dims];
  for (int i = 0; i < num_dims; ++i) {
    dims_copy[i] = static_cast<int64_t>(dims[i]);
  }

  // Calls TF_AllocateTensor from tensorflow/tensorflow/c/tf_tensor.cc.
  TF_Tensor* t = TF_AllocateTensor(static_cast<TF_DataType>(dtype), dims_copy,
                                   num_dims, static_cast<size_t>(sizeInBytes));
  delete[] dims_copy;
  if (dims != nullptr) {
    // JNI_ABORT: the elements were only read, nothing to copy back.
    env->ReleaseLongArrayElements(shape, dims, JNI_ABORT);
  }
  if (t == nullptr) {
    throwException(env, kNullPointerException,
                   "unable to allocate memory for the Tensor");
    return 0;
  }
  return reinterpret_cast<jlong>(t);
}


//From tensorflow/tensorflow/c/tf_tensor.cc
// Allocates a `len`-byte buffer through tensorflow::allocate_tensor using
// tensorflow::cpu_allocator(), then hands it to TF_NewTensor together with
// tensorflow::deallocate_buffer as the deallocator and the CPU allocator as
// the deallocator's argument. Returns whatever TF_NewTensor returns.
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
                             int num_dims, size_t len) {
  void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len,
                                           tensorflow::cpu_allocator());
  return TF_NewTensor(dtype, dims, num_dims, data, len,
                      tensorflow::deallocate_buffer,
                      tensorflow::cpu_allocator());
}

//From tensorflow/tensorflow/c/tf_tensor.cc
// Wraps caller-provided memory (`data`, `len` bytes) in a TF_Tensor with the
// given dtype and dimensions. `deallocator(data, len, deallocator_arg)` is
// invoked either immediately (when the contents had to be copied into a freshly
// allocated buffer) or later by the TF_ManagedBuffer that takes over `data`.
// Returns nullptr when `len` is too small for the requested shape and dtype.
TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
                        void* data, size_t len,
                        void (*deallocator)(void* data, size_t len, void* arg),
                        void* deallocator_arg) {
  const auto tf_dtype = static_cast<tensorflow::DataType>(dtype);

  std::vector<tensorflow::int64> shape_dims(num_dims);
  for (int d = 0; d < num_dims; ++d) {
    shape_dims[d] = static_cast<tensorflow::int64>(dims[d]);
  }

  // TF_STRING and TF_RESOURCE tensors have a different representation in
  // TF_Tensor than they do in tensorflow::Tensor, so copying them here would be
  // a waste (any alignment requirements will be taken care of by
  // TF_TensorToTensor and TF_TensorFromTensor).
  //
  // Other types have the same representation, so a copy is made only when it
  // is safe to memcpy them and the incoming pointer is not suitably aligned.
  const bool needs_aligned_copy =
      dtype != TF_STRING && dtype != TF_RESOURCE &&
      tensorflow::DataTypeCanUseMemcpy(tf_dtype) &&
      reinterpret_cast<intptr_t>(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) !=
          0;

  TF_ManagedBuffer* buf = nullptr;
  if (!needs_aligned_copy) {
    // Adopt the caller's memory as-is; the buffer will run `deallocator`.
    buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
  } else {
    buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len),
                               len, tensorflow::deallocate_buffer, nullptr);
    std::memcpy(buf->data(), data, len);
    // The contents now live in the new buffer; free the original one.
    deallocator(data, len, deallocator_arg);
  }

  // TODO(gjn): Make the choice of interface a compile-time configuration.
  tensorflow::TensorInterface ret(
      Tensor(tf_dtype, tensorflow::TensorShape(shape_dims), buf));
  // Drop this function's own reference to the buffer.
  buf->Unref();

  const size_t elem_size = TF_DataTypeSize(dtype);
  const bool too_small =
      elem_size > 0 && len < (elem_size * ret.NumElements());
  if (too_small) {
    return nullptr;
  }
  return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
}

// From tensorflow/tensorflow/c/tf_tensor_internal.h
// TensorBuffer is defined in tensorflow/tensorflow/core/framework/tensor.h

// A TensorBuffer over externally supplied memory. It remembers the byte length
// and a deallocator callback (plus its opaque argument) and invokes that
// callback on the data pointer when the buffer is destroyed.
class TF_ManagedBuffer : public tensorflow::TensorBuffer {
 public:
  // data:            start of the memory handed to the TensorBuffer base.
  // len:             size of that memory in bytes (reported by size()).
  // deallocator:     called as deallocator(data, len, deallocator_arg) from
  //                  the destructor.
  // deallocator_arg: opaque value passed through to `deallocator`.
  TF_ManagedBuffer(void* data, size_t len,
                   void (*deallocator)(void* data, size_t len, void* arg),
                   void* deallocator_arg)
      : TensorBuffer(data),
        len_(len),
        deallocator_(deallocator),
        deallocator_arg_(deallocator_arg) {}

  // Releases the memory by invoking the stored deallocator.
  ~TF_ManagedBuffer() override {
    (*deallocator_)(data(), len_, deallocator_arg_);
  }

  // Size of the buffer in bytes, as given at construction.
  size_t size() const override { return len_; }
  // This buffer is its own root.
  TensorBuffer* root_buffer() override { return this; }
  // Records size() as the requested byte count and the CPU allocator's name
  // as the allocator name in `proto`.
  void FillAllocationDescription(
      tensorflow::AllocationDescription* proto) const override {
    tensorflow::int64 rb = size();
    proto->set_requested_bytes(rb);
    proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
  }

  // Prevents input forwarding from mutating this buffer.
  bool OwnsMemory() const override { return false; }

 private:
  const size_t len_;  // buffer length in bytes
  void (*const deallocator_)(void* data, size_t len, void* arg);  // free callback
  void* const deallocator_arg_;  // opaque argument for deallocator_
};
// Defined in tensorflow/tensorflow/c/tf_tensor_internal.h

// This struct forms part of the C API's public interface. It must strictly be
// passed to or returned from C functions *by pointer*. Otherwise, changes to
// its internal structure will break the C API's binary interface.
typedef struct TF_Tensor {
  // Owning pointer to the tensor implementation behind this handle.
  std::unique_ptr<AbstractTensorInterface> tensor;
} TF_Tensor;

 

2、Tensor.java

/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obta
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值