class Tensor {
public:
Tensor();
Tensor(DataType type, const TensorShape& shape);
Tensor(Allocator* a, DataType type, const TensorShape& shape);
Tensor(Allocator* a, DataType type, const TensorShape& shape,
const AllocationAttributes& allocation_attr);
explicit Tensor(DataType type);
/// Copy constructor.
Tensor(const Tensor& other);
Tensor(Tensor&& other);
~Tensor();
DataType dtype() const { return shape_.data_type(); }
const TensorShape& shape() const { return shape_; }
int dims() const { return shape().dims(); }
int64 dim_size(int d) const { return shape().dim_size(d); }
int64 NumElements() const { return shape().num_elements(); }
bool IsSameSize(const Tensor& b) const {
return shape().IsSameSize(b.shape());
}
bool SharesBufferWith(const Tensor& b) const;
bool IsInitialized() const;
size_t TotalBytes() const;
size_t AllocatedBytes() const;
bool IsAligned() const {
#if EIGEN_MAX_ALIGN_BYTES == 0
return true;
#else
void* ptr = base<void>();
return reinterpret_cast<intptr_t>(ptr) % EIGEN_MAX_ALIGN_BYTES == 0;
#endif
}
Tensor& operator=(const Tensor& other) {
CopyFromInternal(other, other.shape());
return *this;
}
/// Move operator. See move constructor for details.
Tensor& operator=(Tensor&& other);
bool CopyFrom(const Tensor& other,
const TensorShape& shape) TF_MUST_USE_RESULT {
if (other.NumElements() != shape.num_elements()) return false;
CopyFromInternal(other, shape);
return true;
}
Tensor Slice(int64 dim0_start, int64 dim0_limit) const;
bool FromProto(const TensorProto& other) TF_MUST_USE_RESULT;
bool FromProto(Allocator* a, const TensorProto& other) TF_MUST_USE_RESULT;
void AsProtoField(TensorProto* proto) const;
void AsProtoTensorContent(TensorProto* proto) const;
template <typename T> typename TTypes<T>::Vec vec() {
return tensor<T, 1>();
}
template <typename T> typename TTypes<T>::Matrix matrix() {
return tensor<T, 2>();
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor tensor();
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor bit_casted_tensor();
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor reinterpret_last_dimension();
template <typename T>
typename TTypes<T>::Flat flat() {
return shaped<T, 1>({NumElements()});
}
template <typename T>
typename TTypes<T>::UnalignedFlat unaligned_flat() {
return unaligned_shaped<T, 1>({NumElements()});
}
template <typename T, size_t NDIMS = 2>
typename TTypes<T, NDIMS>::Tensor flat_inner_dims();
template <typename T, size_t NDIMS = 2>
typename TTypes<T, NDIMS>::Tensor flat_outer_dims();
template <typename T, size_t NDIMS = 3>
typename TTypes<T, NDIMS>::Tensor flat_inner_outer_dims(int64 begin);
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor shaped(gtl::ArraySlice<int64> new_sizes);
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor bit_casted_shaped(
gtl::ArraySlice<int64> new_sizes);
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::UnalignedTensor unaligned_shaped(
gtl::ArraySlice<int64> new_sizes);
template <typename T> typename TTypes<T>::Scalar scalar();
/// Const versions of all the methods above.
template <typename T>
typename TTypes<T>::ConstVec vec() const {
return tensor<T, 1>();
}
template <typename T>
typename TTypes<T>::ConstMatrix matrix() const {
return tensor<T, 2>();
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor tensor() const;
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor bit_casted_tensor() const;
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor reinterpret_last_dimension() const;
template <typename T>
typename TTypes<T>::ConstFlat flat() const {
return shaped<T, 1>({NumElements()});
}
template <typename T>
typename TTypes<T>::UnalignedConstFlat unaligned_flat() const {
return unaligned_shaped<T, 1>({NumElements()});
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor shaped(
gtl::ArraySlice<int64> new_sizes) const;
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor bit_casted_shaped(
gtl::ArraySlice<int64> new_sizes) const;
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::UnalignedConstTensor unaligned_shaped(
gtl::ArraySlice<int64> new_sizes) const;
template <typename T>
typename TTypes<T>::ConstScalar scalar() const;
template <typename T, size_t NDIMS = 2>
typename TTypes<T, NDIMS>::ConstTensor flat_inner_dims() const;
template <typename T, size_t NDIMS = 2>
typename TTypes<T, NDIMS>::ConstTensor flat_outer_dims() const;
template <typename T, size_t NDIMS = 3>
typename TTypes<T, NDIMS>::ConstTensor flat_inner_outer_dims(int64 begin) const;
string SummarizeValue(int64 max_entries) const;
string DebugString() const;
void FillDescription(TensorDescription* description) const;
StringPiece tensor_data() const;
void UnsafeCopyFromInternal(const Tensor&, DataType dtype,
const TensorShape&);
private:
bool RefCountIsOne() const;
void CheckType(DataType expected_dtype) const;
void CheckTypeAndIsAligned(DataType expected_dtype) const;
void CheckIsAlignedAndSingleElement() const;
void set_dtype(DataType t) { shape_.set_data_type(t); }
template <size_t NDIMS>
void FillDimsAndValidateCompatibleShape(
gtl::ArraySlice<int64> new_sizes,
Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const;
static gtl::InlinedVector<int64, 4> ComputeFlatInnerDims(
gtl::ArraySlice<int64> orig, int64 num_out_dims);
static gtl::InlinedVector<int64, 4> ComputeFlatOuterDims(
gtl::ArraySlice<int64> orig, int64 num_out_dims);
TensorShape shape_;
TensorBuffer* buf_;
friend class DMAHelper;
friend class TensorCApi;
friend class TensorReference;
friend class VariableOp;
friend class AutoReloadVariableOp;
friend class TensorTestHelper;
template <typename Device, typename T>
friend class CreateVariableOp;
friend class OpKernelContext;
friend class NumpyTensorBuffer;
Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf);
bool CanUseDMA() const;
void set_shape(const TensorShape& shape) {
DataType dt = dtype();
shape_ = shape;
set_dtype(dt);
}
void CopyFromInternal(const Tensor& other, const TensorShape& shape);
template <typename T> T* base() const;
template <size_t NDIMS>
void FillDimsAndValidateCompatibleShape(
Eigen::array<Eigen::DenseIndex, NDIMS>* dims,
gtl::ArraySlice<int64> new_sizes) const;
};