[PyTorch 源码阅读] —— TensorIterator

前言

在介绍正式内容之前,来看一个简单的问题,如何将两个数组相加

从工程师的角度,写一个 for 循环,然后依次相加不是就起来就可以了。

看着很简单。但是,进一步,如果让它很快速,并且可以面对更多的情况呢?

TensorIterator 就是干这个工作的。为了应对不同的情况,并且有一定的性能优化,使得它变得很复杂。

如果给你两个 vector,让你把它们加起来,你可能会先问,这两个 vector 的size 是多少? 因为Tensor 是一个虚拟的多维的概念,所以这里两个 Tensor 相加不一定需要两个 tensor 是相同的 shape。 这就引出了 broadcast 的概念。然后相加的两个 tensor 也可能不是相同的 data type,如果不同的 tensor type 还需要 type promotion。 这两个都是来自 Numpy。 并且支持指定输出的 Dtype 类型,如果不一致也需要考虑。还有 memory overlap,大概就是通过 stride,不同的 index 访问的都是同一块物理内存地址。strides 来实现 broadcast,并且可以用不同的 strides 来解释同一块 dense 内存。同时还需要解决 output layout 的问题,例如如果输入是一个 colmajor 一个 rowmajor。那需要去决定 output 的 layout。

最后是如何是这个过程更高效。index 寻址需要遍历所有的维度,然后依次乘以 stride 计算出真实的物理内存的 offset,但是如果是同维度的tensor 相加,其实可以将它们看做是一维 tensor,然后直接相加,不需要 for 循环去遍历了。 需要去识别这种 case。

mem_overlap, 是指一个算子的输入和输出是不是有部分重叠(view / inplace),并检查输出内部是否重叠(broadcast view)。 外部重叠,主要是直接检查内存地址的大小。

TensorIterator 主要是用来辅助 element-wise 操作类算子的实现,例如算术计算,比较类函数和三角函数等。

TensorIteratorConfig

接下来从代码层面来介绍这个类的部分内容。来看一下上面的部分概念都是如何实现的。

首先要介绍的是 TensorIteratorConfig 类,这个类从名字就可以知道主要是用来做 TensorIterator 的配置用的,具体都会配置一点什么呢

class TORCH_API TensorIteratorConfig final {
public:
  friend struct TensorIteratorBase;
  friend struct TensorIterator;

  TensorIteratorConfig() {}

 // ...
 TensorIteratorConfig& check_all_same_dtype(const bool _check_all_same_dtype) {
    check_all_same_dtype_ = _check_all_same_dtype;
    return *this;
  }

 TensorIteratorConfig& check_all_same_device(const bool _check_all_same_device) {
    check_all_same_device_ = _check_all_same_device;
    return *this;
  }

 TensorIteratorConfig& enforce_safe_casting_to_output(const bool _enforce_safe_casting_to_output) {
    enforce_safe_casting_to_output_ = _enforce_safe_casting_to_output;
    return *this;
  }
 // ...
 TensorIterator build() {
    TensorIterator iter;
    iter.build(*this);
    return iter;
  }

private:
  SmallVector<c10::MaybeOwned<Tensor>, 4> tensors_; // 保存输入和输出 tensor
  int num_outputs_ = 0; // 输出个数
  int num_inputs_ = 0; // 输入个数

  c10::optional<DimVector> static_shape_ = c10::nullopt; // 指定输出 shape
  c10::optional<std::pair<ScalarType, Device>> static_dtype_and_device_ = c10::nullopt;
  bool check_mem_overlap_ = true; // 检查内存是否重叠
  bool allow_cpu_scalars_ = false;
  bool is_reduction_ = false; // 是不是 reduce op
  bool resize_outputs_ = true;
  bool check_all_same_dtype_ = true; // 检查输入输出是否是同一数据类型
  bool check_all_same_device_ = true; // 检查输入输出是否在相同设备上
  bool enforce_safe_casting_to_output_ = false;
  bool enforce_linear_iteration_ = false;
  bool promote_inputs_to_common_dtype_ = false; // 是否需要做 type promote
  bool promote_integer_inputs_to_float_ = false;
  bool cast_common_dtype_to_outputs_ = false; // 是否转化输出数据类型
};

可以从 TensorIteratorConfig 的成员变量看到,为了可以处理各种情况,设置了不同的标志位来明确指示对应行为,并且值得注意的一个点是,它的成员函数都直接 return *this。这样的一个好处就是可以链式调用,使配置更加清晰。例如下面:

#define BINARY_OP_CONFIG()                              \
  TensorIteratorConfig()                                \
    .set_check_mem_overlap(true)                        \
    .allow_cpu_scalars(true)                            \
    .promote_inputs_to_common_dtype(true)               \
    .cast_common_dtype_to_outputs(true)                 \
    .enforce_safe_casting_to_output(true)               \

void TensorIteratorBase::build_binary_op(const Tensor& out, const Tensor& a, const Tensor& b) {
  build(BINARY_OP_CONFIG()
      .add_owned_output(out)
      .add_owned_input(a)
      .add_owned_input(b));
}

TensorIteratorBase

通过上面的用法引出 TensorIteratorBase 类,这个类的成员函数以 TensorIteratorConfig 为基础来实现主要功能,下面是部分源码展示:

struct TORCH_API TensorIteratorBase : public impl::MetaBase {
  using DimMask = std::bitset<64>;
  using PtrVector = SmallVector<char*, 4>;
  using StrideVector = SmallVector<int64_t, 6>;

  TensorIteratorBase();
  void build(TensorIteratorConfig&);

  // ...

protected:
  // Mutable reference as it moves tensors out of TensorIteratorConfig
  void populate_operands(TensorIteratorConfig&); // 获取输入输出 tensor
  void mark_outputs();
  void mark_resize_outputs(const TensorIteratorConfig&); // 标记输出是否需要 resize
  void compute_mem_overlaps(const TensorIteratorConfig&); // 确认内存是否有重叠
  void compute_shape(const TensorIteratorConfig&);  // 计算新的输出 shape, broadcast 功能
  void compute_strides(const TensorIteratorConfig&);  // 计算新的 stride
  void reorder_dimensions(); // 对维度重新排序
  void permute_dimensions(IntArrayRef perm);
  void compute_types(const TensorIteratorConfig&); // 计算输出的 dtype 和 device
  ScalarType compute_common_dtype(); // 计算统一的 dtype,type promotion功能
  void allocate_or_resize_outputs(); // 如果没有提供输出,需要分配内存
  bool fast_set_up(const TensorIteratorConfig&);
  FastSetupType compute_fast_setup_type(const TensorIteratorConfig&);
  void compute_names(const TensorIteratorConfig&);
  void propagate_names_to_outputs();
  void coalesce_dimensions();

// ...
};

上面也是通过 TensorIteratorBase 类的成员函数,就可以看出它的复杂性,需要处理各种复杂的情况。这里面主要就介绍两个函数,分别与 broadcast 和 type promotion 有关。
broadcast 功能就是自动拓展维度,使计算可以得到正确结果,例如:

a=torch.randn(1,3,4)
b=torch.randn(4,1,1)
c = a + b // c.shape == [4,3,4]

其中主要功能实现在上面的 compute_shape 里面,核心实现就是 infer_size_dimvector 函数:

// pytorch/aten/src/ATen/ExpandUtils.cpp
template <typename Container>
Container infer_size_impl(IntArrayRef a, IntArrayRef b) {
  size_t dimsA = a.size();
  size_t dimsB = b.size();
  size_t ndim = dimsA > dimsB ? dimsA : dimsB;
  Container expandedSizes(ndim);

  // Use ptrdiff_t to ensure signed comparison.
  for (ptrdiff_t i = (ptrdiff_t)ndim - 1; i >= 0; --i) {
    ptrdiff_t offset = ndim - 1 - i;
    ptrdiff_t dimA = dimsA - 1 - offset;
    ptrdiff_t dimB = dimsB - 1 - offset;
    int64_t sizeA = (dimA >= 0) ? a[dimA] : 1;
    int64_t sizeB = (dimB >= 0) ? b[dimB] : 1;

    TORCH_CHECK(
        sizeA == sizeB || sizeA == 1 || sizeB == 1,
        "The size of tensor a (", sizeA,
        ") must match the size of tensor b (", sizeB,
        ") at non-singleton dimension ", i);

      // 1s map to the other size (even 0).
      expandedSizes[i] = sizeA == 1 ? sizeB : sizeA;
  }

  return expandedSizes;
}

DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b) {
  return infer_size_impl<DimVector>(a, b);
}

type promotion 就是当输入数据类型不相同时,需要把数据类型向表示范围更大的靠拢,例如 float + double ,结果就需要是 double 类型的。这个功能的相关实现在 compute_types 里面。这个函数很长,这里感兴趣的读者可以自行前往 pytorch/aten/src/ATen/TensorIterator.cpp 阅读对应实现,这里实际需要去计算结果其实 PyTorch 是采用查表的方式:

// pytorch/build/lib.linux-x86_64-3.7/torch/include/c10/core/ScalarType.h

static constexpr ScalarType _promoteTypesLookup[static_cast<int>(
      ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = {
      /*        u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  q1  q2  q3  bf*/
      /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, u1, ud, ud, ud, bf},
      /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, i1, ud, ud, ud, bf},
      /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, i2, ud, ud, ud, bf},
      /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, ud, c4, c8, i4, ud, ud, ud, bf},
      /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, ud, c4, c8, i8, ud, ud, ud, bf},
      /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, ud, c4, c8, f2, ud, ud, ud, f4},
      /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, ud, c4, c8, f4, ud, ud, ud, f4},
      /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, ud, c8, c8, f8, ud, ud, ud, f8},
      /* c2 */ {ud, ud, ud, ud, ud, ud, ud, ud, c2, c4, c8, ud, ud, ud, ud, ud},
      /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4},
      /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8},
      /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, b1, ud, ud, ud, bf},
      /* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
      /* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
      /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
      /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, ud, c4, c8, bf, ud, ud, ud, bf},
  };
  return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];

TensorIterator

最后就是应用类 TensorIterator 了。

struct TORCH_API TensorIterator final : public TensorIteratorBase {
  TensorIterator() : TensorIteratorBase() {}
  // Slicing is OK, TensorIterator guaranteed NOT to have any fields
  TensorIterator(const TensorIteratorBase& iter) : TensorIteratorBase(iter) {}

#define TORCH_DISALLOW_TEMPORARIES(methodname) TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, static)

  static TensorIterator binary_float_op(Tensor& out, const Tensor& a, const Tensor& b);
  static TensorIterator binary_op(Tensor& out, const Tensor& a, const Tensor& b);
  static TensorIterator   (const Tensor& out, const Tensor& a, const Tensor& b);
  TORCH_DISALLOW_TEMPORARIES(borrowing_binary_op)
  static TensorIterator comparison_op(Tensor& out, const Tensor& a, const Tensor& b);
  static TensorIterator unary_op(Tensor& out, const Tensor& a);
  static TensorIterator unary_float_op(Tensor& out, const Tensor& a);
  static TensorIterator nullary_op(Tensor& out);
  static TensorIterator borrowing_nullary_op(const Tensor& out);
  static TensorIterator borrowing_nullary_op(Tensor&& out) = delete;
  static TensorIterator reduce_op(Tensor& out, const Tensor& a);
  static TensorIterator reduce_op(Tensor& out1, Tensor& out2, const Tensor& a);
#undef TORCH_DISALLOW_TEMPORARIES
#undef TORCH_DISALLOW_TEMPORARIES_IMPL

  const Tensor& maybe_get_output(int64_t output_idx) override;
  void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) override;
};

可以看到这里面主要是定义了 TensorIterator 可以处理的算子类型,因为同类型的算子都是有一定的共性才会归为一类,所以就可以提前根据不同算子特性提前设置好对应的配置信息,简化流程。以二值类算子为例,其主要处理两个输入的 pointwise 算子:

#define BINARY_OP_CONFIG()                              \
  TensorIteratorConfig()                                \
    .set_check_mem_overlap(true)                        \
    .allow_cpu_scalars(true)                            \
    .promote_inputs_to_common_dtype(true)               \
    .cast_common_dtype_to_outputs(true)                 \
    .enforce_safe_casting_to_output(true)               \

void TensorIteratorBase::build_binary_op(const Tensor& out, const Tensor& a, const Tensor& b) {
  build(BINARY_OP_CONFIG()
      .add_owned_output(out)
      .add_owned_input(a)
      .add_owned_input(b));
}

TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a, const Tensor& b) {
  TensorIterator iter;
  iter.build_binary_op(out, a, b);
  return iter;
}

下面就可以将算子输入信息创建好的 TensorIterator 对象传入对应的算子实现 kernel 来执行计算了。

// 按位取并算子
// pytorch/aten/src/ATen/native/BinaryOps.cpp
Tensor& bitwise_and_out(
    const Tensor& self,
    const Tensor& other,
    Tensor& result) {
  auto iter = TensorIterator::binary_op(result, self, other);
  bitwise_and_stub(iter.device_type(), iter);
  return result;
}

// bitwise_and cpu kernel
// pytorch/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
void bitwise_and_kernel(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    cpu_kernel(
        iter,
        [](bool a, bool b) {
          return a && b;
        });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [](scalar_t a, scalar_t b) -> scalar_t {
            return a & b;
          },
          [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
            return a & b;
          });
    });
  }
}

为了加速计算,这里 cpu kernel 使用了向量化,并行化等策略。当然也可以派发到其他 device 上进行计算。

以上就是大概介绍了 TensorIterator 功能类的作用。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值