BatchToSpaceOp in TensorFlow

Atrous convolution is the convolution operation introduced in DeepLab V1. It allows explicit control over the resolution at which feature responses are computed in deep convolutional neural networks, and it effectively enlarges the field of view of filters to incorporate larger context, without increasing the number of parameters or the amount of computation.

The authors originally implemented it in the Caffe framework by adding to the im2col layer an option to sparsely sample the underlying feature maps. By the time DeepLab V2 was released, the operation had official TensorFlow support in the form of tf.space_to_batch and tf.batch_to_space.

These two operators reduce atrous convolution to regular convolution, so existing, highly optimized convolution routines can be reused (a runnable sketch follows the list):

  • Subsample the input feature map by a factor equal to the atrous rate $r$, deinterlacing it into $r^2$ reduced-resolution maps, one for each of the $r \times r$ possible shifts;
  • Apply standard convolution to these intermediate feature maps;
  • Reinterlace them back onto the original image resolution.
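This reduction can be reproduced with the public TF2 API. A minimal sketch, with shapes chosen by me so the math stays simple: a 3×3 filter at rate 2 has an effective 5×5 extent, so "SAME" corresponds to zero padding of 2 per side, and 8 + 4 = 12 is already divisible by the rate, so no extra alignment padding is needed:

```python
import tensorflow as tf

rate = 2
x = tf.random.normal([1, 8, 8, 3])   # NHWC input
w = tf.random.normal([3, 3, 3, 4])   # HWIO filter

# Direct atrous convolution.
y_direct = tf.nn.atrous_conv2d(x, w, rate=rate, padding="SAME")

# The same computation via the reduction:
# deinterlace -> regular VALID convolution -> reinterlace.
net = tf.space_to_batch(x, block_shape=[rate, rate], paddings=[[2, 2], [2, 2]])
net = tf.nn.conv2d(net, w, strides=[1, 1, 1, 1], padding="VALID")
y_reduced = tf.batch_to_space(net, block_shape=[rate, rate],
                              crops=[[0, 0], [0, 0]])

print(float(tf.reduce_max(tf.abs(y_direct - y_reduced))))  # ~0.0, up to float error
```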

The underlying principle is described in "The discrete wavelet transform: wedding the a trous and Mallat algorithms".

tf.nn.separable_conv2d, tf.nn.convolution, tf.nn.atrous_conv2d, and tf.nn.with_space_to_batch all end up calling tf.batch_to_space:

atrous_conv2d / convolution / separable_conv2d
  -> with_space_to_batch
       -> space_to_batch ... batch_to_space

tf.space_to_batch and tf.batch_to_space share the same underlying implementation. The rest of this post walks through tf.batch_to_space.

tf.batch_to_space can be viewed as a composition of Reshape -> Permute -> Reshape -> Crop. For a 4-D $\mathrm{NHWC}$ input, the core step is equivalent to a permutation with order [1, 3, 0, 2, 4] (written as the destination axis of each source axis; under NumPy's transpose convention the equivalent axis order is [2, 0, 3, 1, 4]):

$$\mathrm{d_h \times d_w \times \tfrac{N}{d_h d_w}H \times W \times C} \;\xrightarrow{\;transpose\;}\; \mathrm{\tfrac{N}{d_h d_w}H \times d_h \times W \times d_w \times C}$$
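A minimal NumPy sketch of this decomposition (the function name batch_to_space_ref is mine; crops are given as [[crop_top, crop_bottom], [crop_left, crop_right]]):

```python
import numpy as np

def batch_to_space_ref(x, dh, dw, crops):
    """batch_to_space as reshape -> transpose -> reshape -> crop (NHWC)."""
    n, h, w, c = x.shape
    n_out = n // (dh * dw)
    # [N, H, W, C] -> [dh, dw, N'H, W, C], keeping N' and H merged.
    t = x.reshape(dh, dw, n_out * h, w, c)
    # [dh, dw, N'H, W, C] -> [N'H, dh, W, dw, C]  (axis order [2, 0, 3, 1, 4])
    t = t.transpose(2, 0, 3, 1, 4)
    # Interleave the block offsets into the spatial dimensions.
    t = t.reshape(n_out, h * dh, w * dw, c)
    (ct, cb), (cl, cr) = crops
    return t[:, ct:h * dh - cb, cl:w * dw - cr, :]

x = np.arange(16, dtype=np.float32).reshape(4, 2, 2, 1)
print(batch_to_space_ref(x, 2, 2, [[0, 0], [0, 0]])[0, :, :, 0])
# Matches tf.batch_to_space(x, [2, 2], [[0, 0], [0, 0]]).
```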

BatchToSpaceOp

OpKernelConstruction::GetAttr retrieves a single value by attribute name.
The op takes only one attribute, block_size.
TensorShape inherits from TensorShapeBase.
A Tensor is constructed from a data type and a shape.
block_shape_vec has two elements, both set to block_size_.
Tensor::vec maps the Tensor to a one-dimensional Eigen::Tensor, which makes assignment convenient.

template <typename Device, typename T>
class BatchToSpaceOp : public OpKernel {
 public:
  explicit BatchToSpaceOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        context, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
    block_shape_ = Tensor(tensorflow::DT_INT64, TensorShape({2}));
    auto block_shape_vec = block_shape_.vec<int64_t>();
    block_shape_vec(0) = block_size_;
    block_shape_vec(1) = block_size_;
  }

BatchToSpaceOp::Compute

Both BatchToSpaceOp::Compute and BatchToSpaceNDOp::Compute delegate to a shared helper, BatchToSpaceOpCompute.

BatchToSpaceOpCompute is an internal static function. Compute reads the two op inputs (the data tensor in0 and the crops in1), checks the rank, and forwards them together with the block_shape_ tensor built in the constructor.

  void Compute(OpKernelContext* context) override {
    const Tensor& in0 = context->input(0);
    const Tensor& in1 = context->input(1);
    const int dims = in0.dims();

    // Check on the input dimensions first.
    // The input is presumed to be [batch, height, width, depth]
    static const int kRequiredDims = 4;
    OP_REQUIRES(context, kRequiredDims == dims,
                errors::InvalidArgument("Input rank should be: ", kRequiredDims,
                                        "instead of: ", dims));
    BatchToSpaceOpCompute<Device, T>(context, in0, block_shape_, in1);
  }

 private:
  int block_size_;
  Tensor block_shape_;
};
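From Python this kernel corresponds to the raw op; a minimal TF2 usage sketch:

```python
import tensorflow as tf

x = tf.reshape(tf.range(4, dtype=tf.float32), [4, 1, 1, 1])
y = tf.raw_ops.BatchToSpace(input=x, crops=[[0, 0], [0, 0]], block_size=2)
print(y.shape)                 # (1, 2, 2, 1)
print(y[0, :, :, 0].numpy())   # [[0. 1.], [2. 3.]]
```

Note that block_size is a scalar attribute here, whereas the ND variant takes block_shape as a tensor input; the constructor above bridges the two by materializing block_shape_ = [block_size_, block_size_].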

BatchToSpaceOpCompute

TensorShapeUtils::IsVector checks the number of dimensions; TensorShapeUtils::IsMatrix does the same for rank 2.
orig_block_shape must be a vector (rank 1).
The input rank must be at least block_dims + 1.
orig_crops must be a matrix of shape [block_dims, 2], matching block_dims.

template <typename Device, typename T>
static void BatchToSpaceOpCompute(OpKernelContext* context,
                                  const Tensor& orig_input_tensor,
                                  const Tensor& orig_block_shape,
                                  const Tensor& orig_crops) {
  const int input_dims = orig_input_tensor.dims();
  OP_REQUIRES(
      context, TensorShapeUtils::IsVector(orig_block_shape.shape()),
      errors::InvalidArgument("block_shape rank should be 1 instead of ",
                              orig_block_shape.dims()));

  const int block_dims = orig_block_shape.dim_size(0);
  OP_REQUIRES(
      context, orig_input_tensor.dims() >= 1 + block_dims,
      errors::InvalidArgument("input rank should be >= ", 1 + block_dims,
                              " instead of ", orig_input_tensor.dims()));

  OP_REQUIRES(context,
              TensorShapeUtils::IsMatrix(orig_crops.shape()) &&
                  block_dims == orig_crops.dim_size(0) &&
                  2 == orig_crops.dim_size(1),
              errors::InvalidArgument("crops should have shape [", block_dims,
                                      ", 2] instead of ",
                                      orig_crops.shape().DebugString()));

gtl::InlinedVector is an alias of absl::InlinedVector.
internal::spacetobatch::SubtleMustCopyFlat copies the data from the Tensor into the vector.
block_shape has an inline capacity of 4, which suggests that more than 4 dimensions is uncommon.

  // To avoid out-of-bounds access in the case that the block_shape and/or
  // crops tensors are concurrently modified, we must copy the values.
  gtl::InlinedVector<int64_t, 4> block_shape;
  gtl::InlinedVector<int64_t, 8> crops;
  internal::spacetobatch::SubtleMustCopyFlat(orig_block_shape, &block_shape);
  internal::spacetobatch::SubtleMustCopyFlat(orig_crops, &crops);

A prefix of block dimensions that has no cropping and block_shape equal to 1 can be folded into the batch dimension.
removed_prefix_block_dims counts the folded prefix block dimensions.

  // Determine the length of the prefix of block dims that can be combined
  // into the batch dimension due to having no padding and block_shape=1.
  int removed_prefix_block_dims = 0;
  for (; removed_prefix_block_dims < block_dims; ++removed_prefix_block_dims) {
    const int dim = removed_prefix_block_dims;
    if (crops[2 * dim] != 0 || crops[2 * dim + 1] != 0 ||
        block_shape[dim] != 1) {
      break;
    }
  }

Likewise, a suffix of block dimensions with no cropping and block_shape equal to 1 can be folded into the depth dimension.
Scanning from back to front, removed_suffix_block_dims counts the dimensions to fold away (a Python mirror of both scans follows the loop below).

  // Determine the length of the suffix of block dims that can be combined
  // into the depth dimension due to having no padding and block_shape=1.
  int removed_suffix_block_dims = 0;
  for (; removed_suffix_block_dims < block_dims - removed_prefix_block_dims;
       ++removed_suffix_block_dims) {
    const int dim = block_dims - 1 - removed_suffix_block_dims;
    if (crops[2 * dim] != 0 || crops[2 * dim + 1] != 0 ||
        block_shape[dim] != 1) {
      break;
    }
  }
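A small, purely illustrative Python mirror of the two scans (the function name fold_block_dims is mine):

```python
def fold_block_dims(block_shape, crops):
    """Count the prefix/suffix block dims that can be folded away.

    crops is the flattened [len(block_shape), 2] crop matrix.
    """
    n = len(block_shape)
    prefix = 0
    while (prefix < n and block_shape[prefix] == 1
           and crops[2 * prefix] == 0 and crops[2 * prefix + 1] == 0):
        prefix += 1
    suffix = 0
    while suffix < n - prefix:
        d = n - 1 - suffix
        if block_shape[d] != 1 or crops[2 * d] != 0 or crops[2 * d + 1] != 0:
            break
        suffix += 1
    return prefix, suffix, n - prefix - suffix  # internal_block_dims last

# Block dims 0 and 3 are trivial, so only 2 internal block dims remain:
print(fold_block_dims([1, 2, 2, 1], [0] * 8))  # (1, 1, 2)
```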

block_shape_product is the product of all block sizes, i.e. the total number of blocks.
Only the original batch dimension gets divided into blocks, so it must be divisible by block_shape_product.

  // Compute the product of the block_shape values.
  int64_t block_shape_product = 1;
  for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
    block_shape_product *= block_shape[block_dim];
  }
  OP_REQUIRES(
      context, block_shape_product > 0,
      errors::InvalidArgument("Product of block sizes must be positive, got ",
                              block_shape_product));

  const int64_t orig_input_batch_size = orig_input_tensor.dim_size(0);
  OP_REQUIRES(
      context, orig_input_batch_size % block_shape_product == 0,
      errors::InvalidArgument("Input batch dimension (", orig_input_batch_size,
                              ") is not divisible by product of block sizes (",
                              block_shape_product, ")"));

internal_block_dims is the number of block dimensions that still need processing. If it is 0, the input passes through unchanged.

  const int internal_block_dims =
      block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
  OP_REQUIRES(context, internal_block_dims <= kMaxSpaceToBatchBlockDims,
              errors::InvalidArgument(
                  "Maximum number of non-combined block dimensions is ",
                  internal_block_dims, " but must not exceed ",
                  kMaxSpaceToBatchBlockDims));

  if (internal_block_dims == 0) {
    context->set_output(0, orig_input_tensor);
    return;
  }

TensorShapeBase::AddDimWithStatus appends a dimension at the end of a shape.
input_batch_size is the batch size after folding the removed prefix dimensions into the batch dimension.

  // For the purpose of computing the result, the input will be treated as
  // having this shape, of rank 2 + internal_block_dims.
  TensorShape internal_input_shape;

  // For the purpose of computing the result, the output will be treated as
  // having this shape, of rank 2 + internal_block_dims.
  TensorShape internal_output_shape;

  // The actual output shape exposed to callers.
  TensorShape external_output_shape;

  OP_REQUIRES_OK(context, external_output_shape.AddDimWithStatus(
                              orig_input_batch_size / block_shape_product));

  int64_t input_batch_size = orig_input_batch_size;
  for (int block_dim = 0; block_dim < removed_prefix_block_dims; ++block_dim) {
    const int64_t size = orig_input_tensor.dim_size(block_dim + 1);
    input_batch_size *= size;
    OP_REQUIRES_OK(context, external_output_shape.AddDimWithStatus(size));
  }
  OP_REQUIRES_OK(context,
                 internal_input_shape.AddDimWithStatus(input_batch_size));
  OP_REQUIRES_OK(context, internal_output_shape.AddDimWithStatus(
                              input_batch_size / block_shape_product));

For the internal block dimensions, validate the crop values.
input_size is read from dimension block_dim + 1 of the input.
Tiling the blocks back into space and then cropping yields the final output size: cropped_size = input_size * block_shape_value - crop_start - crop_end.

  for (int block_dim = removed_prefix_block_dims;
       block_dim < block_dims - removed_suffix_block_dims; ++block_dim) {
    const int64_t crop_start = crops[2 * block_dim],
                  crop_end = crops[2 * block_dim + 1];
    OP_REQUIRES(context, crop_start >= 0 && crop_end >= 0,
                errors::InvalidArgument("Crops must be non-negative"));
    const int64_t input_size = orig_input_tensor.dim_size(block_dim + 1);
    const int64_t block_shape_value = block_shape[block_dim];
    const int64_t cropped_size =
        input_size * block_shape_value - crop_start - crop_end;
    OP_REQUIRES(context, cropped_size >= 0,
                errors::InvalidArgument("cropped_shape[", block_dim, "]=",
                                        cropped_size, " must be non-negative"));
    OP_REQUIRES_OK(context, internal_input_shape.AddDimWithStatus(input_size));
    OP_REQUIRES_OK(context,
                   internal_output_shape.AddDimWithStatus(cropped_size));
    OP_REQUIRES_OK(context,
                   external_output_shape.AddDimWithStatus(cropped_size));
  }
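As a concrete example (my numbers): for an input of shape [8, 3, 3, 1] with block_shape = [2, 2] and crops = [[0, 1], [0, 1]], block_shape_product = 4, so the output batch is 8 / 4 = 2 and each spatial dimension becomes 3 * 2 - 0 - 1 = 5; hence internal_input_shape = [8, 3, 3, 1], internal_output_shape = [2, 5, 5, 1], and external_output_shape = [2, 5, 5, 1].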

For the trailing dimensions (the removed suffix block dimensions plus everything after the block dimensions), compute the folded depth and append it to internal_input_shape and internal_output_shape; external_output_shape keeps the individual sizes.

  int64_t depth = 1;
  for (int dim = block_dims - removed_suffix_block_dims + 1; dim < input_dims;
       ++dim) {
    const int64_t size = orig_input_tensor.dim_size(dim);
    OP_REQUIRES_OK(context, external_output_shape.AddDimWithStatus(size));
    depth *= size;
  }
  OP_REQUIRES_OK(context, internal_input_shape.AddDimWithStatus(depth));
  OP_REQUIRES_OK(context, internal_output_shape.AddDimWithStatus(depth));

internal_crops and internal_block_shape point at the block dimensions that actually need processing.
TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS expands the four cases (1 through 4), turning the runtime block-dimension count into a compile-time constant for SpaceToBatchFunctor.

  // Allocate output tensor.
  Tensor* output_tensor = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(0, external_output_shape,
                                                   &output_tensor));

  const int64_t* internal_crops = &crops[2 * removed_prefix_block_dims];
  const int64_t* internal_block_shape = &block_shape[removed_prefix_block_dims];

  switch (internal_block_dims) {
#define TF_BATCHTOSPACE_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS)                   \
  case NUM_BLOCK_DIMS: {                                                  \
    OP_REQUIRES_OK(                                                       \
        context,                                                          \
        (functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, true>()( \
            context->eigen_device<Device>(),                              \
            output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>(                 \
                internal_output_shape.dim_sizes()),                       \
            internal_block_shape, internal_crops,                         \
            orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>(              \
                internal_input_shape.dim_sizes()))));                     \
  } break;                                                                \
    /**/
    TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(TF_BATCHTOSPACE_BLOCK_DIMS_CASE)
#undef TF_BATCHTOSPACE_BLOCK_DIMS_CASE
  }
}

SpaceToBatchFunctor<CPUDevice, T, NUM_BLOCK_DIMS, B2S>

template <typename T, int NUM_BLOCK_DIMS, bool B2S>
struct SpaceToBatchFunctor<CPUDevice, T, NUM_BLOCK_DIMS, B2S> {
  using SpaceT = typename std::conditional<B2S, T, const T>::type;
  using BatchT = typename std::conditional<B2S, const T, T>::type;
  Status operator()(
      const CPUDevice& d,
      typename TTypes<SpaceT, NUM_BLOCK_DIMS + 2>::Tensor space_tensor,
      const int64_t block_shape_tensor[NUM_BLOCK_DIMS],
      const int64_t paddings_tensor[NUM_BLOCK_DIMS * 2],
      typename TTypes<BatchT, NUM_BLOCK_DIMS + 2>::Tensor batch_tensor) {
    const int64_t batch_tensor_batch = batch_tensor.dimension(0);

    const int64_t space_tensor_batch = space_tensor.dimension(0);

Values are copied out of the Eigen::Tensor into local arrays.

    // Copy into local array so that the compiler is free to place in a
    // register.
    int64_t pad_start[NUM_BLOCK_DIMS];
    int64_t block_shape[NUM_BLOCK_DIMS];
    int64_t space_tensor_shape[NUM_BLOCK_DIMS],
        batch_tensor_shape[NUM_BLOCK_DIMS];
    for (int block_dim = 0; block_dim < NUM_BLOCK_DIMS; ++block_dim) {
      pad_start[block_dim] = paddings_tensor[block_dim * 2];
      block_shape[block_dim] = block_shape_tensor[block_dim];
      space_tensor_shape[block_dim] = space_tensor.dimension(block_dim + 1);
      batch_tensor_shape[block_dim] = batch_tensor.dimension(block_dim + 1);
    }

The innermost stride of space_tensor_strides and batch_tensor_strides is 1.
The remaining strides are computed from back to front.

    int64_t space_tensor_strides[NUM_BLOCK_DIMS + 2],
        batch_tensor_strides[NUM_BLOCK_DIMS + 2];
    space_tensor_strides[NUM_BLOCK_DIMS + 1] =
        batch_tensor_strides[NUM_BLOCK_DIMS + 1] = 1;
    for (int dim = NUM_BLOCK_DIMS; dim >= 0; --dim) {
      space_tensor_strides[dim] =
          space_tensor_strides[dim + 1] * space_tensor.dimension(dim + 1);
      batch_tensor_strides[dim] =
          batch_tensor_strides[dim + 1] * batch_tensor.dimension(dim + 1);
    }

    // Use non-const pointers for both input and output to simplify template
    // implementation given lack of constexpr if.
    T* space_tensor_ptr = const_cast<T*>(space_tensor.data());
    T* batch_tensor_ptr = const_cast<T*>(batch_tensor.data());

For each batch entry batch_tensor_b, compute the corresponding space-tensor batch index space_tensor_b and the block index block_index (blocks are laid out above the batch dimension, hence the modulo and division by space_tensor_batch).
Going from the lowest to the highest block dimension, decompose block_index into per-dimension block_offsets.
SpaceToBatchHelper then processes a single block.

    for (int64_t batch_tensor_b = 0; batch_tensor_b < batch_tensor_batch;
         ++batch_tensor_b) {
      const int64_t space_tensor_b = batch_tensor_b % space_tensor_batch;
      int64_t block_index = batch_tensor_b / space_tensor_batch;
      int64_t block_offsets[NUM_BLOCK_DIMS];
      for (int block_dim = NUM_BLOCK_DIMS - 1; block_dim >= 0; --block_dim) {
        // Skip unnecessary remainder operation for block_dim == 0.
        block_offsets[block_dim] =
            block_dim > 0 ? block_index % block_shape[block_dim] : block_index;
        block_index /= block_shape[block_dim];
      }

      // The compiler should inline the nested loops generated by this template.
      SpaceToBatchHelper<NUM_BLOCK_DIMS, B2S>::run(
          space_tensor_ptr + space_tensor_b * space_tensor_strides[0],
          space_tensor_shape, &space_tensor_strides[1], block_shape, pad_start,
          block_offsets, batch_tensor_shape, &batch_tensor_strides[1],
          batch_tensor_ptr + batch_tensor_b * batch_tensor_strides[0]);
    }
    return OkStatus();
  }
};
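The batch-to-block index decomposition in that outer loop can be made concrete with a few lines of Python (my sketch; the sizes are arbitrary):

```python
space_tensor_batch = 2
block_shape = [2, 2]  # d_h, d_w

for batch_tensor_b in range(space_tensor_batch * block_shape[0] * block_shape[1]):
    space_tensor_b = batch_tensor_b % space_tensor_batch
    block_index = batch_tensor_b // space_tensor_batch
    # Same back-to-front decomposition as the loop above; the modulo is
    # skipped for the outermost block dimension.
    off_w = block_index % block_shape[1]
    off_h = block_index // block_shape[1]
    print(f"batch {batch_tensor_b} -> space batch {space_tensor_b}, "
          f"offsets ({off_h}, {off_w})")
```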

SpaceToBatchHelper

// Implementation of nested loops for SpaceToBatchOpFunctor.
//
// To simplify template implementation given lack of constexpr if, both the
// input and output pointers are non-const.
template <int N, bool B2S>
struct SpaceToBatchHelper {
  template <typename T>
  static void run(T* space_tensor_ptr, const int64_t* space_tensor_shape,
                  const int64_t* space_tensor_strides,
                  const int64_t* block_shape, const int64_t* pad_start,
                  const int64_t* block_offsets,
                  const int64_t* batch_tensor_shape,
                  const int64_t* batch_tensor_strides, T* batch_tensor_ptr) {

batch_tensor_shape[0] is the size of the current block dimension in the batch tensor.
space_tensor_pos is the position in the space tensor corresponding to batch_tensor_pos.
If space_tensor_pos lies within [0, space_tensor_shape[0]), recurse to establish the correspondence for each remaining dimension; space_tensor_ptr + space_tensor_pos * space_tensor_strides[0] is the base address of the next dimension.
Otherwise the position is padding. In the space-to-batch direction zeros are written; a single write of batch_tensor_strides[0] elements covers the whole remaining extent, so deeper dimensions need not be visited.

    for (int64_t batch_tensor_pos = 0; batch_tensor_pos < batch_tensor_shape[0];
         ++batch_tensor_pos) {
      const int64_t space_tensor_pos =
          batch_tensor_pos * block_shape[0] + block_offsets[0] - pad_start[0];
      if (space_tensor_pos >= 0 && space_tensor_pos < space_tensor_shape[0]) {
        SpaceToBatchHelper<N - 1, B2S>::run(
            space_tensor_ptr + space_tensor_pos * space_tensor_strides[0],
            space_tensor_shape + 1, space_tensor_strides + 1, block_shape + 1,
            pad_start + 1, block_offsets + 1, batch_tensor_shape + 1,
            batch_tensor_strides + 1, batch_tensor_ptr);
      } else {
        if (B2S == false) {
          // Copy in padding.
          for (int64_t i = 0; i < batch_tensor_strides[0]; ++i) {
            batch_tensor_ptr[i] = static_cast<T>(0);
          }
        }
      }
      batch_tensor_ptr += batch_tensor_strides[0];
    }
  }
};

For the depth dimension (the base case), elements are copied one by one; batch_tensor_strides[-1] is the number of contiguous elements left in the innermost extent.

template <bool B2S>
struct SpaceToBatchHelper<0, B2S> {
  template <typename T>
  static void run(T* space_tensor_ptr, const int64_t* space_tensor_shape,
                  const int64_t* space_tensor_strides,
                  const int64_t* block_shape, const int64_t* pad_start,
                  const int64_t* block_offsets,
                  const int64_t* batch_tensor_shape,
                  const int64_t* batch_tensor_strides, T* batch_tensor_ptr) {
    for (int64_t i = 0; i < batch_tensor_strides[-1]; ++i) {
      if (B2S == false) {
        batch_tensor_ptr[i] = space_tensor_ptr[i];
      } else {
        space_tensor_ptr[i] = batch_tensor_ptr[i];
      }
    }
  }
};
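For intuition, here is a compact, purely illustrative Python analogue of the recursion for a single block dimension in the batch-to-space direction, with flat lists standing in for the raw pointers:

```python
def b2s_one_block_dim(space, space_len, block, pad_start, block_offset,
                      batch, batch_len, depth):
    """Scatter one block of a flattened [len, depth] tensor back into space."""
    for batch_pos in range(batch_len):
        space_pos = batch_pos * block + block_offset - pad_start
        if 0 <= space_pos < space_len:
            for i in range(depth):  # base case: element-wise copy
                space[space_pos * depth + i] = batch[batch_pos * depth + i]
        # Out-of-range positions were cropped: nothing to write back.

# Invert a space_to_batch of [0, 1, 2, 3] with block 2 (no padding):
space = [0.0] * 4
b2s_one_block_dim(space, 4, 2, 0, 0, [0.0, 2.0], 2, 1)  # block offset 0
b2s_one_block_dim(space, 4, 2, 0, 1, [1.0, 3.0], 2, 1)  # block offset 1
print(space)  # [0.0, 1.0, 2.0, 3.0]
```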

SpaceToBatchFunctor<GPUDevice, T, NUM_BLOCK_DIMS, B2S>

template <typename T, int NUM_BLOCK_DIMS, bool B2S>
struct SpaceToBatchFunctor<GPUDevice, T, NUM_BLOCK_DIMS, B2S> {
  using SpaceT = typename std::conditional<B2S, T, const T>::type;
  using BatchT = typename std::conditional<B2S, const T, T>::type;
  Status operator()(
      const GPUDevice& d,
      typename TTypes<SpaceT, NUM_BLOCK_DIMS + 2>::Tensor space_tensor,
      const int64 block_shape[NUM_BLOCK_DIMS],
      const int64 paddings[NUM_BLOCK_DIMS * 2],
      typename TTypes<BatchT, NUM_BLOCK_DIMS + 2>::Tensor batch_tensor) {
    // Kernel execution fails if number of elements is zero.
    if (batch_tensor.size() == 0) {
      return OkStatus();
    }

The S2BParameters struct carries the shape information.
Every block_shape entry must fit into int32.
The values are copied into the members of args.

    S2BParameters<NUM_BLOCK_DIMS> args;
    args.space_tensor_batch = space_tensor.dimension(0);
    for (int block_dim = 0; block_dim < NUM_BLOCK_DIMS; ++block_dim) {
      if (block_shape[block_dim] > std::numeric_limits<int32>::max()) {
        return errors::InvalidArgument("block_shape value exceeds 2^32-1");
      }
      args.block_shape[block_dim] = block_shape[block_dim];
      if (space_tensor.dimension(block_dim + 1) >
          std::numeric_limits<int32>::max()) {
        return errors::InvalidArgument("space_tensor dimension exceeds 2^32-1");
      }
      args.space_tensor_spatial_shape[block_dim] =
          space_tensor.dimension(block_dim + 1);
      if (paddings[block_dim * 2] > std::numeric_limits<int32>::max()) {
        return errors::InvalidArgument("paddings/crops value exceeds 2^32-1");
      }
      args.pad_start[block_dim] = paddings[block_dim * 2];
    }

total_count is the total number of elements of batch_tensor.
GetGpuLaunchConfig returns a GpuLaunchConfig computed from the element count and the device properties.
GpuLaunchKernel launches the S2B kernel with no shared memory (the 0 argument).
config.virtual_thread_count is also passed into S2B as its first argument.

    int64 total_count = 1;
    for (int dim = 0; dim < NUM_BLOCK_DIMS + 2; ++dim) {
      args.batch_tensor_shape[dim] = batch_tensor.dimension(dim);
      total_count *= args.batch_tensor_shape[dim];
    }
    if (total_count > std::numeric_limits<int32>::max()) {
      return errors::InvalidArgument(
          "number of batch_tensor elements exceeds 2^32-1");
    }
    GpuLaunchConfig config =
        GetGpuLaunchConfig(static_cast<int32>(total_count), d);
    return GpuLaunchKernel(S2B<T, NUM_BLOCK_DIMS, B2S>, config.block_count,
                           config.thread_per_block, 0, d.stream(),
                           config.virtual_thread_count,
                           const_cast<T*>(space_tensor.data()), args,
                           const_cast<T*>(batch_tensor.data()));
  }
};

S2B

GPU_1D_KERNEL_LOOP wraps the grid-stride for loop built from GpuGridRangeX.
Each (virtual) thread handles one element.

// GPU kernel for space-to-batch (if B2S = false) and batch-to-space conversion
// (if B2S = true).
//
// To simplify template implementation given lack of constexpr if, both the
// input and output pointers are non-const.
template <typename T, int NUM_BLOCK_DIMS, bool B2S>
__global__ void S2B(const int32 nthreads, T* __restrict__ space_tensor_ptr,
                    S2BParameters<NUM_BLOCK_DIMS> args,
                    T* __restrict__ batch_tensor_ptr) {
  GPU_1D_KERNEL_LOOP(batch_tensor_idx, nthreads) {

Starting from the lowest dimension, batch_tensor_idx is decomposed into the per-dimension indices batch_tensor_pos of the batch tensor.

    int32 remaining_batch_tensor_idx = batch_tensor_idx;

    int32 batch_tensor_pos[NUM_BLOCK_DIMS + 2];

    for (int dim = NUM_BLOCK_DIMS + 1; dim >= 1; --dim) {
      batch_tensor_pos[dim] =
          remaining_batch_tensor_idx % args.batch_tensor_shape[dim];
      remaining_batch_tensor_idx /= args.batch_tensor_shape[dim];
    }
    batch_tensor_pos[0] = remaining_batch_tensor_idx;

In batch_tensor the block dimensions sit above the batch dimension, i.e. the layout is dh × dw × no × hi × wi.
remaining_block_idx is the (still-packed) block index.
space_tensor_idx is the flat index into space_tensor.
space_tensor_stride is the current stride of space_tensor.
space_tensor_batch_pos is the batch index in space_tensor.
Iterating over the block dimensions from lowest to highest, space_tensor_pos is the corresponding position in space_tensor. If it falls outside the valid range, the space-to-batch direction writes zero padding; batch-to-space simply skips the element, since it was cropped.
For block_dim == 0 the remainder is skipped because remaining_block_idx has already been divided down below block_shape[0]; this is the same "skip unnecessary remainder" optimization as in the CPU path.
offset is the index along the current block dimension.

    int32 remaining_block_idx = batch_tensor_pos[0] / args.space_tensor_batch;
    int32 space_tensor_idx = batch_tensor_pos[NUM_BLOCK_DIMS + 1];
    int32 space_tensor_stride = args.batch_tensor_shape[NUM_BLOCK_DIMS + 1];
    const int32 space_tensor_batch_pos =
        batch_tensor_pos[0] % args.space_tensor_batch;
    for (int block_dim = NUM_BLOCK_DIMS - 1; block_dim >= 0; --block_dim) {
      int32 offset = remaining_block_idx;
      if (block_dim > 0) {
        offset %= args.block_shape[block_dim];
      }
      int32 space_tensor_pos =
          batch_tensor_pos[block_dim + 1] * args.block_shape[block_dim] +
          offset - args.pad_start[block_dim];
      if (space_tensor_pos < 0 ||
          space_tensor_pos >= args.space_tensor_spatial_shape[block_dim]) {
        if (B2S == false) {
          // In the space-to-batch case, write zero padding.
          batch_tensor_ptr[batch_tensor_idx] = static_cast<T>(0);
        }
        break;
      }

The offset for the current dimension is accumulated into space_tensor_idx.
Once block_dim reaches 0, all dimensions have been flattened into a single index.
ldg loads data through the read-only cache.
Dividing remaining_block_idx by args.block_shape[block_dim] yields the packed block index for the next-higher dimension.

      space_tensor_idx += space_tensor_stride * space_tensor_pos;
      space_tensor_stride *= args.space_tensor_spatial_shape[block_dim];
      if (block_dim == 0) {
        space_tensor_idx += space_tensor_stride * space_tensor_batch_pos;
        if (B2S == false) {
          batch_tensor_ptr[batch_tensor_idx] =
              ldg(space_tensor_ptr + space_tensor_idx);
        } else {
          space_tensor_ptr[space_tensor_idx] =
              ldg(batch_tensor_ptr + batch_tensor_idx);
        }
      }
      remaining_block_idx /= args.block_shape[block_dim];
    }
  }
}
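The per-element index arithmetic can be checked with a scalar Python mirror of the kernel (my sketch, exercised here with two block dimensions; it returns None for cropped positions):

```python
def batch_to_space_index(batch_tensor_idx, batch_tensor_shape,
                         space_tensor_batch, space_spatial_shape,
                         block_shape, pad_start):
    """Map a flat batch_tensor index to a flat space_tensor index (or None)."""
    # Decompose the flat index, lowest dimension first.
    pos = [0] * len(batch_tensor_shape)
    rem = batch_tensor_idx
    for dim in range(len(batch_tensor_shape) - 1, 0, -1):
        pos[dim] = rem % batch_tensor_shape[dim]
        rem //= batch_tensor_shape[dim]
    pos[0] = rem

    remaining_block_idx = pos[0] // space_tensor_batch
    space_idx = pos[-1]
    stride = batch_tensor_shape[-1]
    space_batch_pos = pos[0] % space_tensor_batch
    for block_dim in range(len(block_shape) - 1, -1, -1):
        offset = remaining_block_idx
        if block_dim > 0:
            offset %= block_shape[block_dim]
        space_pos = (pos[block_dim + 1] * block_shape[block_dim]
                     + offset - pad_start[block_dim])
        if not 0 <= space_pos < space_spatial_shape[block_dim]:
            return None  # cropped (or, in the S2B direction, padded) position
        space_idx += stride * space_pos
        stride *= space_spatial_shape[block_dim]
        if block_dim == 0:
            space_idx += stride * space_batch_pos
        remaining_block_idx //= block_shape[block_dim]
    return space_idx

shape = [4, 1, 2, 1]  # batch tensor: N=4, H=1, W=2, C=1, block 2x2, no crops
for idx in range(8):
    print(idx, "->",
          batch_to_space_index(idx, shape, 1, [2, 4], [2, 2], [0, 0]))
# e.g. idx 2 (batch 1, w 0) -> 1: block offset (0, 1) lands next to w = 0.
```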

