Volta884_h884gemm in CUTLASS 1.3.3


CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. It adopts hierarchical decomposition and data-movement strategies similar to those used in the implementations of cuBLAS and cuDNN.

The latest CUTLASS release is 3.3, which differs considerably from 1.3.3. Revisiting 1.3.3 is still worthwhile, though, because it is easier to understand:

Demystifying Tensor Cores to Optimize Half-Precision Matrix Multiply observes that on the T4 GPU, once Tensor Cores were introduced, GEMM went from being compute-bound to IO-bound. Although the V100 has three times the bandwidth of the T4, the bandwidth shortfall exists there as well. CUTLASS therefore applies the following optimizations to the data path (see the sketch after this list):

  • 128-bit access granularity along the entire path: LDG.128, STS.128, LDS.128, STG.128;
  • Conflict-free Shared Memory layout: no padding of Shared Memory is required for transposition;
  • Software pipelining: the LDG.128, LDS.128, and HMMA.884.F16.F16 instructions proceed in parallel, hiding data movement.
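As a concrete illustration of the first point, 128-bit accesses are what the compiler emits when each thread loads or stores a 16-byte aligned vector type. A minimal sketch (not CUTLASS code):

#include <cuda_runtime.h>

// Each uint4 is 16 bytes (eight half elements), so every thread moves
// 128 bits per access: the compiler emits LDG.128 / STG.128.
__global__ void copy_128bit(const uint4* __restrict__ src, uint4* dst, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    uint4 v = src[i];  // LDG.128
    dst[i] = v;        // STG.128
  }
}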

Below, the implementation of Volta884_h884gemm is walked through using one of the matrix-multiplication test cases.

TEST(Volta884_h884gemm_128x64x32_nt, 520x264x136)

OutputTile is the threadblock tile, set to 32x64x128 in this test case. Shape lists the K dimension first (kD, kH, kW), so Shape<32, 64, 128> is a 128x64 threadblock tile with K = 32, matching the 128x64x32 in the test name. WarpGemmShape is 32x64x64, which is a fixed value.
run_gemm initializes Volta884GemmTraits::Params and a GemmTestbed, calls Gemm::launch, and compares the results after the run.

TEST(Volta884_h884gemm_128x64x32_nt, 520x264x136) {

  typedef cutlass::gemm::Volta884GemmTraits<
    cutlass::MatrixLayout::kColumnMajor,  // layout of A ("n": column-major)
    cutlass::MatrixLayout::kRowMajor,     // layout of B ("t": row-major)
    cutlass::Shape<32, 64, 128>,          // OutputTile: threadblock tile, K=32, N=64, M=128
    cutlass::Shape<32, 64, 64>,           // WarpGemmShape (fixed for Volta884)
    half,                                 // accumulator type
    half,                                 // scalar type of C
    half,                                 // scalar type of D
    2                                     // StageCount: shared memory stages
  > GemmTraits;

  run_gemm<GemmTraits>(520, 264, 136);
}

The hierarchy of the Volta884 implementation in CUTLASS is shown in the diagram below:

[Hierarchy diagram. Roughly: Gemm is parameterized by Volta884GemmTraits and runs GemmMainloop with an IdentityBlockSwizzle; the mainloop drives a GlobalLoadStreamPair (two MMAGlobalLoadStreams, each built on TileLoadIterator, Copy, and Volta884ThreadblockMultiplicandStoreIterator) and a SharedStreamPair (two MMASharedLoadStreams, each built on Volta884WarpMultiplicandLoadIterator and Copy), feeding Volta884MultiplyAdd over Volta884Multiplicand fragments; the epilogue is MMAEpilogue with Volta884EpilogueTraits, built on PredicatedTileLoadStream, PredicatedTileStoreStream, TileLoadStream, and TileStoreStream.]

gemm_kernel_nolb


The kernel function declares dynamic Shared Memory, passes it to GemmMainloop, and then calls GemmMainloop::multiply_add to carry out the computation.

/// GEMM kernel without launch bounds specified
template <typename Gemm_>
__global__ /* __launch_bounds__(Gemm_::kThreads) */
void gemm_kernel_nolb(typename Gemm_::Params params) {

  // Dynamic shared memory base pointer
  extern __shared__ int GemmSharedStorageBase[];

  // Declare pointer to dynamic shared memory.
  typename Gemm_::SharedStorage *shared_storage = 
    reinterpret_cast<typename Gemm_::SharedStorage *>(GemmSharedStorageBase);

  // Construct the GEMM object.
  Gemm_ gemm(params, *shared_storage);

  // Run GEMM.
  gemm.multiply_add();
}
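Since SharedStorage lives in dynamic Shared Memory, the host must pass its size at launch and opt in when it exceeds 48 KB. A minimal sketch of such a launcher, using the hypothetical helper name launch_gemm_nolb (CUTLASS 1.3 does the equivalent inside Gemm::launch):

#include <cuda_runtime.h>

// Hypothetical helper (not the CUTLASS API): launch gemm_kernel_nolb with
// dynamic shared memory sized from Gemm_::SharedStorage.
template <typename Gemm_>
cudaError_t launch_gemm_nolb(typename Gemm_::Params const& params, dim3 grid) {
  int smem_size = int(sizeof(typename Gemm_::SharedStorage));
  if (smem_size >= (48 << 10)) {
    // Volta allows opting in to more than 48 KB of dynamic shared memory.
    cudaError_t result = cudaFuncSetAttribute(
        gemm_kernel_nolb<Gemm_>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        smem_size);
    if (result != cudaSuccess) {
      return result;
    }
  }
  gemm_kernel_nolb<Gemm_><<<grid, dim3(Gemm_::kThreads, 1, 1), smem_size>>>(params);
  return cudaGetLastError();
}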

GemmMainloop

GemmMainloop implements the software pipeline, illustrated below:
[Figure: software pipeline of GemmMainloop]

Shared Memory and the registers each need two buffers; scheduling on the SM lets the three pipelines run in parallel. The load from Global Memory to Shared Memory requires synchronization, whereas the move from Shared Memory to registers does not. Because architectures before Ampere cannot copy directly from Global Memory to Shared Memory, the whole transfer is fairly involved. As the following diagram shows, the program calls the Copy::transform function in several places to produce a transformed_fragment. The intent is presumably type conversion, but since Volta884 supports only half it has no practical effect here.

[Data-flow diagram. For A/B: GlobalLoadStream::copy (TileLoadIterator::load_post_increment) reads Global_Memory into fetched_fragment; GlobalLoadStream::commit runs Copy::transform to produce transformed_fragment, then Volta884ThreadblockMultiplicandStoreIterator::store_post_increment writes it to Shared_Memory; MMASharedLoadStream::copy loads fetched_fragment from Shared_Memory, and MMASharedLoadStream::commit transforms it for Volta884MultiplyAdd, which updates the accumulators. For the epilogue: PredicatedTileLoadStream::copy and TileLoadStream::commit fetch and transform C; LinearScaling combines it with the accumulators; TileStoreStream::copy, Copy::transform, TileStoreIterator::store_post_increment, and TileStoreStream::commit write D.]
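The kernel below is a self-contained sketch of this double-buffered pattern (an illustration only, not CUTLASS code): while the math consumes the tile held in one Shared Memory buffer, the next tile is fetched from Global Memory and committed into the other buffer.

// Minimal double-buffering sketch: one block of 256 threads sums `tiles`
// consecutive 256-element tiles of `in` (launch: <<<1, 256>>>). The global
// fetch of tile k+1 overlaps the math on tile k; the two __syncthreads()
// play the roles of shared_load_fence and shared_store_fence.
__global__ void pipelined_tile_sum(const float* __restrict__ in, float* out, int tiles) {
  __shared__ float smem[2][256];

  float next = in[threadIdx.x];   // prolog: fetch tile 0 (LDG)
  smem[0][threadIdx.x] = next;    // prolog: commit tile 0 (STS)
  __syncthreads();                // prolog: shared_store_fence

  float acc = 0.0f;
  int buf = 0;
  for (int k = 0; k < tiles; ++k) {
    if (k + 1 < tiles) {
      next = in[(k + 1) * 256 + threadIdx.x];  // fetch tile k+1 early (LDG)
    }
    acc += smem[buf][threadIdx.x];             // math on tile k (LDS + FFMA)
    if (k + 1 < tiles) {
      __syncthreads();                     // shared_load_fence: tile k consumed
      smem[buf ^ 1][threadIdx.x] = next;   // commit tile k+1 (STS)
      __syncthreads();                     // shared_store_fence: tile k+1 visible
      buf ^= 1;
    }
  }
  out[threadIdx.x] = acc;
}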
template <typename Traits_>
struct GemmMainloop {

  //
  // Type definitions
  //

  /// The traits.
  typedef Traits_ Traits;

  /// The GEMM mainloop
  typedef typename Traits::KernelClass KernelClass;

  /// The shared storage.
  typedef typename Traits::SharedStorage SharedStorage;

  /// The scalar for A.
  typedef typename Traits::ScalarA ScalarA;
  /// The scalar for B.
  typedef typename Traits::ScalarB ScalarB;
  /// The scalar in the epilogue.
  typedef typename Traits::Epilogue::Scalar ScalarEpilogue;
  /// The scalar for C.
  typedef typename Traits::Epilogue::ScalarC ScalarC;
  /// The scalar for D.
  typedef typename Traits::Epilogue::ScalarD ScalarD;
  /// The index.
  typedef typename Traits::Index Index;

  /// Define the mainloop iteration size
  typedef typename Traits::MultiplyAdd MultiplyAdd;

  /// The number of threads.
  static int const kThreads = Traits::GemmConfig::kThreads;

AccumulatorsPerWarp, i.e. GemmConfig::AccumulatorsPerWarp, equals Volta884MultiplyAdd::WarpGemmShape, which is 32x64x64.
Volta884MultiplyAdd::InstructionShape is 4x32x32. Therefore kWarpGemmSteps = AccumulatorsPerWarp::kD / InstructionShape::kD = 32 / 4 = 8.

  // Number of warp-level multiply-accumulate steps executed by each warp.
  static Index const kWarpGemmSteps =
      Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;

  /*
  // Make sure we have at least 2 unrolling steps or our pipeling is not going to work.
  static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps");
  */

  /// Use the params object defined in traits
  typedef typename Traits::Params Params;

  //
  // Data members
  //

  /// The params.
  Params const& params;

  /// SharedStorage object
  SharedStorage& shared_storage;
  //
  // Methods
  //

  /// Ctor.
  CUTLASS_DEVICE GemmMainloop(Params const& params_, SharedStorage& shared_storage_)
      : params(params_), shared_storage(shared_storage_) {}

GemmMainloop::fetch_global


Volta884GemmTraits::GlobalLoadStream is of type GlobalLoadStreamPair.
GlobalLoadStreamPair::residue calls MMAGlobalLoadStream::residue twice to compute the predicate masks needed for the last load of a threadblock tile.
GlobalLoadStreamPair::copy calls MMAGlobalLoadStream::copy twice to copy matrix elements from Global Memory into registers; the latter calls TileLoadIterator::load_post_increment.

  /// Fetches global stream pair
  template <bool Residue>
  CUTLASS_DEVICE void fetch_global(typename Traits::GlobalLoadStream& global_to_shared_stream,
                                   Index outer_k) {
    // If residue portion and not calculating residue in prolog, update residue predicates now.
    if (Residue) {
      global_to_shared_stream.residue(outer_k);
    }
    global_to_shared_stream.copy();
  }
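The residue predicates mentioned above are per-access guard bits: out-of-bounds elements are masked rather than branched around. A generic sketch of the idea (not the actual TileLoadIterator logic):

// Guarded tile load: each element is loaded only if its K coordinate is still
// inside the problem; masked elements are zero-filled. Predicates like these
// are what residue() recomputes for the final, partial tile.
__device__ void guarded_load(float frag[4], const float* ptr,
                             int k_offset, int k_bound, int stride) {
  #pragma unroll
  for (int i = 0; i < 4; ++i) {
    bool pred = (k_offset + i * stride) < k_bound;  // predicate for element i
    frag[i] = pred ? ptr[i * stride] : 0.0f;
  }
}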

GemmMainloop::consume_tile

If kWarpGemmSteps is at most 4, kGlobalStreamFirst is set and the data for the next iteration is fetched from Global Memory first.

  /// Computes a warp-level GEMM on data held in shared memory
  template <bool Residue, bool LastIteration>
  CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream,
                                   typename Traits::SharedStream& shared_load_stream,
                                   typename MultiplyAdd::Accumulators& accumulators,
                                   Index outer_k) {

    // Whether to load global stream before loading shared stream
    const bool kGlobalStreamFirst = (kWarpGemmSteps <= 4);

    // Load data for the next iteration of the main loop (unless it's the last iteration).
    if (kGlobalStreamFirst && !LastIteration) {
      fetch_global<Residue>(global_to_shared_stream, outer_k);
    }

First, the inputs for the next step are loaded from Shared Memory; it is double-buffered.
MMASharedLoadStream::copy calls Volta884WarpMultiplicandLoadIterator::load to load the data into registers.
One question: if GemmMainloop::fetch_global was not called before this point, is copying from Shared Memory safe? It appears to be: whatever is read here was committed to Shared Memory either in the prolog or at step kWarpGemmSteps - 2 of the previous tile iteration, with a fence between commit and consumption.

    CUTLASS_PRAGMA_UNROLL
    for (int step = 0; step < kWarpGemmSteps; ++step) {

      // Trigger the copy from shared memory for the next A/B values.
      shared_load_stream.copy((step + 1) % kWarpGemmSteps);

If kGlobalStreamFirst is not set, GemmMainloop::fetch_global is called to load the inputs at the first step of the loop.

      // Load data for the next iteration of the main loop (unless it's the last iteration).
      if (!kGlobalStreamFirst && (step == 0) && !LastIteration) {
        fetch_global<Residue>(global_to_shared_stream, outer_k);
      }

If this is the second-to-last step, the data must be guaranteed to have arrived in Shared Memory:
Volta884GemmTraits::shared_load_fence decides whether to synchronize threads based on the externally supplied StageCount.
GlobalLoadStreamPair::commit calls GlobalLoadStream::commit for each of the two matrices to copy their data into Shared Memory.
Volta884GemmTraits::shared_store_fence synchronizes the threads.
MMASharedLoadStream::inc_stage increments stage_index.

      if (step == kWarpGemmSteps - 2) {
          // Make sure the data from shared memory has been entirely consumed.
          Traits::shared_load_fence(true);

          global_to_shared_stream.commit();

          // Make sure the data is in shared memory.
          Traits::shared_store_fence(true);

          // Move to the next stage for the load (if it makes sense).
          shared_load_stream.inc_stage();
      }

MMASharedLoadStream::commit uses Copy to transform the Volta884WarpMultiplicandLoadIterator::Fragment into the Fragment consumed by the math.
Volta884MultiplyAdd::multiply_add then performs the warp-tile computation.

      // Make sure the values are available for the current iteration to do the multiply-add.
      shared_load_stream.commit(step);

      // Do the math on the fragments of the current iteration.
      MultiplyAdd multiply_add;
      multiply_add.multiply_add(shared_load_stream.fragment_a(step),
                                shared_load_stream.fragment_b(step),
                                accumulators,
                                accumulators);
    }
  }

GemmMainloop::multiply_add

[Flowchart: GemmMainloop::multiply_add → IdentityBlockSwizzle::get_threadblock_offset → IdentityBlockSwizzle::get_threadblock_bounds → IdentityBlockSwizzle::get_batch_id → GlobalLoadStreamPair::add_batch_offset → GlobalLoadStreamPair::move_to_residue → GlobalLoadStreamPair::copy → GlobalLoadStreamPair::commit → Volta884GemmTraits::shared_store_fence → GlobalLoadStreamPair::rollback → SharedLoadStream::copy → ClearAccumulators::clear → GemmMainloop::consume_tile → GemmEpilogue::epilogue → End]

make_Coord_from_shape creates a Coord object from a Shape.
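Presumably it just materializes the Shape's compile-time extents as a runtime coordinate. A usage sketch (the include paths are an assumption about the 1.3 tree):

#include "cutlass/coord.h"
#include "cutlass/shape.h"

// OutputTile from the test case above: kD=32 (K), kH=64 (N), kW=128 (M).
typedef cutlass::Shape<32, 64, 128> OutputTile;

// Yields a Coord<3> holding (32, 64, 128).
cutlass::Coord<3> tile = cutlass::make_Coord_from_shape<OutputTile>();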

IdentityBlockSwizzle::get_threadblock_offset returns the current threadblock's offset in the two-dimensional output.
Volta884GemmTraits::ClearAccumulators is ClearAccumulators.
IdentityBlockSwizzle::get_threadblock_bounds returns the threadblock's three-dimensional bounds.

  /// Do the GEMM.
  CUTLASS_DEVICE void multiply_add() {
    // Swizzle the IDs of the block (to enable better cache behavior).
    typename Traits::BlockSwizzle block_swizzle;
    Coord<3> threadblock_offset =
        block_swizzle.get_threadblock_offset(make_Coord_from_shape<typename Traits::OutputTile>());

    // We may want to use shared memory to clear the registers.
    typedef typename Traits::ClearAccumulators ClearAccumulators;

    // Get the bounds for each thread, it maybe different than problem_size
    Coord<3> bounds = block_swizzle.get_threadblock_bounds(params.problem_size,
                                                        params.partitionK_range);

params.global_to_shared_stream is of type GlobalLoadStreamPair::Params.
shared_storage.main_loop.global_to_shared_stream is of type GlobalLoadStreamPair::SharedStorage.
shared_storage.main_loop.threadblock_tile is of type GlobalLoadStreamPair::ThreadblockTileStorage, i.e. ZipTileAllocation; ZipTileAllocation::reference returns a ZipTensorRef object pointing to the data.
global_to_shared_stream is of type Volta884GemmTraits::GlobalLoadStream, i.e. GlobalLoadStreamPair.
GlobalLoadStreamPair::add_batch_offset calls GlobalLoadStream::add_batch_offset on both streams to set the iterators' batch offsets.

    // The streams to read A/B from global memory to shared memory.
    typename Traits::GlobalLoadStream global_to_shared_stream(
        params.global_to_shared_stream,
        shared_storage.main_loop.global_to_shared_stream,
        shared_storage.main_loop.threadblock_tile.reference(),
        bounds,
        threadblock_offset);

    // update A and B pointer offset based on batch_id and batch_stride_offset
    global_to_shared_stream.add_batch_offset(block_swizzle.get_batch_id());

    // Create the accumulator clear.
    ClearAccumulators clear;

GlobalLoadStreamPair::move_to_residue calls MMAGlobalLoadStream::move_to_residue to move the pointers if the residue is handled in the prolog; otherwise it calls GlobalLoadStreamPair::residue directly.
GlobalLoadStreamPair::copy calls MMAGlobalLoadStream::copy, which calls TileLoadIterator::load_post_increment to load fragments of A and B into Fragment registers.
GlobalLoadStreamPair::commit calls MMAGlobalLoadStream::commit, which calls Copy::transform and then
Volta884ThreadblockMultiplicandStoreIterator::store_post_increment to store to Shared Memory.
Volta884GemmTraits::shared_store_fence synchronizes the threads within the threadblock.
GlobalLoadStreamPair::rollback calls MMAGlobalLoadStream::rollback, which calls TileLoadIterator::initialize_predicates to initialize the predicate vector and then moves the offset back. For the test above, K = 136 and OutputTile::kD = 32, so the prolog loads the 136 mod 32 = 8 residue elements first and then rolls back to the start of the first full tile.

    // Deal with residue in prolog.
    // global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD);
    global_to_shared_stream.move_to_residue(bounds[0], Traits::OutputTile::kD);

    // Fetch the fragments for A and B from global memory.
    global_to_shared_stream.copy();

    // Copy the elements to shared memory (after transformation if needed).
    global_to_shared_stream.commit();

    // Make sure the data is in shared memory.
    Traits::shared_store_fence(false);

    // Rollback to the beginning of the first tile (if residue exists).
    // global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD);
    global_to_shared_stream.rollback(bounds[0] % Traits::OutputTile::kD);

shared_load_stream is of Volta884GemmTraits::SharedStream type, i.e. SharedStreamPair.
SharedStreamPair::copy calls MMASharedLoadStream::copy, which calls Volta884WarpMultiplicandLoadIterator::load to load from Shared Memory.
accumulators is of Volta884MultiplyAdd::Accumulators type, i.e. a Fragment.
ClearAccumulators::clear calls Fragment::clear to zero the storage.
What is outer_k? Judging from the code below, it is the K index of the next threadblock tile to fetch: it starts at bounds[0] - OutputTile::kD and decreases by OutputTile::kD on each mainloop iteration.

    // The stream of data from shared memory to fragments.
    typename Traits::SharedStream shared_load_stream(
        params.shared_stream,
        shared_storage.main_loop.threadblock_tile.reference());

    // Trigger the copy from shared memory for the 1st stream.
    shared_load_stream.copy(0);

    // Allocate the accumulators.
    typename MultiplyAdd::Accumulators accumulators;

    // Clear the accumulators.
    clear.clear(accumulators);

    // Initial index
    // Index outer_k = params.problem_size[0] - Traits::OutputTile::kD;
    // problem_size[0] might be bigger than bounds[0]
    Index outer_k = bounds[0] - Traits::OutputTile::kD;

If the residue was computed in the prolog, it only needs to be handled on the last iteration.
GemmMainloop::consume_tile computes one tile with k = Traits::OutputTile::kD. For the test above (K = 136, kD = 32), outer_k starts at 104 and the loop runs for outer_k = 104, 72, 40, 8 before the final consume_tile call: five tiles in total.

    // Check if we are computing residue in prolog or not.
    if (Traits::GemmConfig::kResidueInProlog) {
      // Execute all mainloop iterations but the last one.

      CUTLASS_GEMM_LOOP
      for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) {
        CUTLASS_GEMM_LOOP_HEADER
        consume_tile<false, false>(
            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
      }

      consume_tile<false, true>(
          global_to_shared_stream, shared_load_stream, accumulators, outer_k);

Otherwise, the residue is taken into account on every remaining iteration.

    } else {
      // When kResidueSeparate = true, execute all mainloop iterations but the last two without any
      // consideration for K-residue or predicate updates. This improves the steady state of some
      // kernels.
      if (Traits::GemmConfig::kResidueSeparate) {

        CUTLASS_GEMM_LOOP
        for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
          CUTLASS_GEMM_LOOP_HEADER
          consume_tile<false, false>(
              global_to_shared_stream, shared_load_stream, accumulators, outer_k);
        }
      }

      // Execute remaining tiles with K-residue predicate updates enabled.
      CUTLASS_GEMM_LOOP
      for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
        CUTLASS_GEMM_LOOP_HEADER
        consume_tile<true, false>(
            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
      }
    }

Finally, an MMAEpilogue object is created and its MMAEpilogue::epilogue function is called.

    typedef typename Traits::Epilogue Epilogue;
    Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm());
    epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id());
  }
};
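The LinearScaling functor that appears in the data-flow figure is the epilogue's element-wise update. A minimal sketch of its effect (an illustration of the semantics, not the CUTLASS class):

// d = alpha * accumulator + beta * c, applied element-wise by the epilogue.
template <typename Scalar>
struct LinearScalingSketch {
  Scalar alpha;
  Scalar beta;
  Scalar operator()(Scalar accum, Scalar c) const {
    return alpha * accum + beta * c;
  }
};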
