CUDA编程之CUDA Sample-3_CUDA_Features-cudaTensorCoreGemm

CUDA sample中3_CUDA_Features里包含一些展示 CUDA 各种特性的sample。cudaTensorCoreGemm这个sample展示了如何使用 WMMA (Warp Matrix Multiply Accumulate) API 执行矩阵乘法。

#include <assert.h>
#include <cuda.h>
#include <mma.h>
#include <stdio.h>

// helper functions and utilities to work with CUDA
#include <helper_cuda.h>
#include <helper_functions.h>

// Externally configurable parameters.

#ifndef CPU_DEBUG
// Set this to 1 to verify the correctness of the GPU-computed matrix.
#define CPU_DEBUG 0
#endif

#ifndef SHARED_MEMORY_LIMIT_64K
// Set this to 0 to use more than 64 Kb of shared memory to cache data, to
// improve the performance of the computations on GPU.
// Note that you need a GPU that can have more than 64 Kb of shared memory
// per multiprocessor.
#define SHARED_MEMORY_LIMIT_64K 1
#endif

// GPU configuration.

#define WARP_SIZE 32

// MMA matrix tile dimensions.

#define M 16
#define N 16
#define K 16

#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16

// GEMM configuration.

#define M_TILES 256
#define N_TILES 256
#define K_TILES 256

#define M_GLOBAL (M * M_TILES)
#define N_GLOBAL (N * N_TILES)
#define K_GLOBAL (K * K_TILES)

#define C_LAYOUT wmma::mem_row_major

// Implementation constants.

#define WARPS_PER_BLOCK 8
#define THREADS_PER_BLOCK (WARP_SIZE * WARPS_PER_BLOCK)

#if SHARED_MEMORY_LIMIT_64K
// With only 64 Kb shared memory available, we can fit two 8-tile chunks of
// the A and B matrix data, that are 16 * 16 * 8 * 8 * 2 = 32 Kb each
// (i.e. two 8x8 arrays of tiles of 16x16 half-typed elements per CTA).
// But we cannot account the 8 Kb total skew overhead, without which the
// performance would be severely impacted. So we choose to reduce the chunk size
// in half, i.e. the amount of A and B matrix data we cache in shared memory.
// Accordingly, this doubles the number of outer iterations across the global K
// dimension, which only slightly impacts the performance.
#define CHUNK_K 4
#else
#define CHUNK_K 8
#endif

#define CHUNK_LINE_BYTES (CHUNK_K * K * sizeof(half))
#define WARP_COPY_BYTES (WARP_SIZE * sizeof(int4))
#define CHUNK_COPY_LINES_PER_WARP (WARP_COPY_BYTES / CHUNK_LINE_BYTES)
#define CHUNK_COPY_LINE_LANES (WARP_SIZE / CHUNK_COPY_LINES_PER_WARP)

#define BLOCK_ROW_WARPS 2
#define BLOCK_COL_WARPS 4

#define WARP_ROW_TILES 4
#define WARP_COL_TILES 2

#define BLOCK_ROW_TILES (WARP_ROW_TILES * BLOCK_ROW_WARPS)
#define BLOCK_COL_TILES (WARP_COL_TILES * BLOCK_COL_WARPS)

#define GLOBAL_MEM_STRIDE N_GLOBAL

#define SHMEM_STRIDE (N * BLOCK_ROW_TILES)
#define SHMEM_OFFSET (N * WARP_ROW_TILES)

// The macro below is used to shift rows of the A matrix and columns of the B matrix
// in shared memory to minimize possible bank conflicts.
// Before performing the nvcuda::wmma::mma_sync operation, the warp must load the matrix
// data using the nvcuda::wmma::load_matrix_sync operation. Although the memory access pattern
// is not specified for that function, each lane in the warp can read one or multiple matrix
// elements from different matrix rows or columns.
// For shared memory, such access can result in bank conflicts if different rows / columns
// of the matrix map to the same bank. By shifting each row and column by a few bytes, we
// make sure that they map to different banks, thus reducing the number of possible bank
// conflicts.
// The number of 16 two-byte "half" elements is chosen as the minimum possible shift because
// we must keep each row and column 256-bit aligned, as required by nvcuda::wmma::load_matrix_sync.
#define SKEW_HALF 16

#define checkKernelErrors(expr)                             \
  do {                                                      \
    expr;                                                   \
                                                            \
    cudaError_t __err = cudaGetLastError();                 \
    if (__err != cudaSuccess) {                             \
      printf("Line %d: '%s' failed: %s\n", __LINE__, #expr, \
             cudaGetErrorString(__err));                    \
      abort();                                              \
    }                                                       \
  } while (0)

using namespace nvcuda;

__host__ void init_host_matrices(half *a, half *b, float *c) {
  for (int i = 0; i < M_GLOBAL; i++) {
    for (int j = 0; j < K_GLOBAL; j++) {
      a[i * K_GLOBAL + j] = (half)(rand() % 3);
    }
  }

  for (int i = 0; i < N_GLOBAL; i++) {
    for (int j = 0; j < K_GLOBAL; j++) {
      b[i * K_GLOBAL + j] = (half)(rand() % 3);
    }
  }

  for (int t = 0; t < M_GLOBAL * N_GLOBAL; t++) {
    c[t] = static_cast<float>(rand() % 3);
  }
}

__global__ void compute_gemm(const half *A, const half *B, const float *C,
                             float *D, float alpha, float beta) {
  extern __shared__ half shmem[][CHUNK_K * K + SKEW_HALF];

  // Warp and lane identification.
  const unsigned int warpId = threadIdx.x / WARP_SIZE;
  const unsigned int laneId = threadIdx.x % WARP_SIZE;

  // Offset in shared memory from which the B matrix is stored.
  const size_t shmem_idx_b_off = BLOCK_COL_TILES * M;

  // This pointer is used to access the C and D matrix tiles this warp computes.
  float *shmem_warp_tile_ptr = (float *)&shmem[0][0] +
                               (warpId / 2) * SHMEM_STRIDE * K * 2 +
                               (warpId % 2) * SHMEM_OFFSET;

  // This pointer is used to stream the C and D matrices block-wide tile to and
  // from shared memory.
  float *shmem_warp_stream_ptr =
      (float *)&shmem[0][0] + warpId * SHMEM_STRIDE * K;

  // Adjust the beta scaler, as it'll be multiplied by alpha at the end of
  // each tile computation. Technically this is not generally correct (may
  // result in a loss of precision). Zero still needs to be specially handled
  // though.
  beta /= alpha;

  // Each CTA slides along the 128 x 128 tiles from the top left corner of the
  // matrix to the right and down, and selects the next tile to compute. Once
  // there's no such tile, all warps in this CTA exit.
  for (unsigned int block_pos = blockIdx.x;; block_pos += gridDim.x) {
    const unsigned int block_tile_i =
        ((block_pos * BLOCK_ROW_TILES) / N_TILES) * (BLOCK_COL_TILES);
    const unsigned int block_tile_j = (block_pos * BLOCK_COL_TILES) % N_TILES;

    // Stop when there are no more D matrix tiles to compute in this CTA.
    if (block_tile_i >= M_TILES) {
      break;
    }

    // This warp's pointer to the C matrix data to copy memory from to shared
    // memory.
    const size_t gmem_idx =
        (block_tile_i + warpId) * M * GLOBAL_MEM_STRIDE + block_tile_j * N;
    const float *src_gmem_warp_stream_ptr = &C[gmem_idx];

    // Stream multiple C tiles to shared memory.
#pragma unroll
    for (int i = 0; i < K; i++) {
      typedef int4 copy_t;

      *((copy_t *)(shmem_warp_stream_ptr + SHMEM_STRIDE * i) + laneId) =
          *((copy_t *)(src_gmem_warp_stream_ptr + GLOBAL_MEM_STRIDE * i) +
            laneId);
    }

    __syncthreads();

    // These fragments will accumulate the result of A and B matrix fragment
    // multiplications along the K_GLOBAL dimension.
    wmma::fragment<wmma::accumulator, M, N, K, float> c[WARP_COL_TILES]
                                                       [WARP_ROW_TILES];

    // Load the C matrix tiles into fragments from shared memory.
#pragma unroll
    for (int i = 0; i < WARP_COL_TILES; i++) {
#pragma unroll
      for (int j = 0; j < WARP_ROW_TILES; j++) {
        const float *tile_ptr =
            shmem_warp_tile_ptr + i * SHMEM_STRIDE * K + j * N;

        wmma::load_matrix_sync(c[i][j], tile_ptr, SHMEM_STRIDE, C_LAYOUT);
      }
    }

    __syncthreads();

    // Scale the C matrix.
#pragma unroll
    for (int i = 0; i < WARP_COL_TILES; i++) {
#pragma unroll
      for (int j = 0; j < WARP_ROW_TILES; j++) {
#pragma unroll
        for (int t = 0; t < c[i][j].num_elements; t++) {
          c[i][j].x[t] *= beta;
        }
      }
    }

    // Select what warp copies what matrix to shared memory.
    // Warps 0-3 copy the A matrix, warps 4-7 copy the B matrix.
    const half *warp_ptr = (warpId < 4) ? (&A[block_tile_i * M * K_GLOBAL] +
                                           M * K_GLOBAL * (warpId % 4) * 2)
                                        : (&B[block_tile_j * N * K_GLOBAL] +
                                           N * K_GLOBAL * (warpId % 4) * 2);

    // Go through the global K dimension by a fixed step at a time.
#pragma unroll
    for (int tile_k = 0; tile_k < K_TILES; tile_k += CHUNK_K) {
      // Copy slices of the A and B matrices to shared memory.
      // The first half of the warps in the CTA copy the A matrix, the rest copy
      // the B matrix.
      size_t shmem_idx =
          warpId < (WARPS_PER_BLOCK / 2)
              ? (M * (warpId % (WARPS_PER_BLOCK / 2)) * 2)
              : (N * (warpId % (WARPS_PER_BLOCK / 2)) * 2 + shmem_idx_b_off);

      // First half of the warp copies the first row / column of the matrix,
      // the second half of the warp copies the next.
      int4 *lane_ptr = (int4 *)(warp_ptr + tile_k * K +
                                (laneId / CHUNK_COPY_LINE_LANES) * K_GLOBAL) +
                       (laneId % CHUNK_COPY_LINE_LANES);

      // Shift the second half of the warp to the next row / column in the
      // shared memory.
      shmem_idx += laneId / CHUNK_COPY_LINE_LANES;

#pragma unroll
      for (int i = 0; i < ((WARP_SIZE / 2) / CHUNK_COPY_LINES_PER_WARP) * 2;
           i++) {
        // Copy 16 bytes at once in each lane.
        *((int4 *)&shmem[shmem_idx][0] + (laneId % CHUNK_COPY_LINE_LANES)) =
            *lane_ptr;

        // Advance the global memory pointer and the shared memory index.
        lane_ptr =
            (int4 *)((half *)lane_ptr + K_GLOBAL * CHUNK_COPY_LINES_PER_WARP);
        shmem_idx += CHUNK_COPY_LINES_PER_WARP;
      }

      __syncthreads();

      // Compute a grid of C matrix tiles in each warp.
#pragma unroll
      for (int k_step = 0; k_step < CHUNK_K; k_step++) {
        wmma::fragment<wmma::matrix_a, M, N, K, half, wmma::row_major>
            a[WARP_COL_TILES];
        wmma::fragment<wmma::matrix_b, M, N, K, half, wmma::col_major>
            b[WARP_ROW_TILES];

#pragma unroll
        for (int i = 0; i < WARP_COL_TILES; i++) {
          size_t shmem_idx_a = (warpId / 2) * M * 2 + (i * M);
          const half *tile_ptr = &shmem[shmem_idx_a][k_step * K];

          wmma::load_matrix_sync(a[i], tile_ptr, K * CHUNK_K + SKEW_HALF);

#pragma unroll
          for (int j = 0; j < WARP_ROW_TILES; j++) {
            if (i == 0) {
              // Load the B matrix fragment once, because it is going to be
              // reused against the other A matrix fragments.
              size_t shmem_idx_b = shmem_idx_b_off +
                                   (WARP_ROW_TILES * N) * (warpId % 2) +
                                   (j * N);
              const half *tile_ptr = &shmem[shmem_idx_b][k_step * K];

              wmma::load_matrix_sync(b[j], tile_ptr, K * CHUNK_K + SKEW_HALF);
            }

            wmma::mma_sync(c[i][j], a[i], b[j], c[i][j]);
          }
        }
      }

      __syncthreads();
    }

      // Store the D fragments to shared memory.
#pragma unroll
    for (int i = 0; i < WARP_COL_TILES; i++) {
#pragma unroll
      for (int j = 0; j < WARP_ROW_TILES; j++) {
#pragma unroll
        // Uniform, point-wise transformations of ALL fragment elements by ALL
        // threads in the warp are well-defined even though element indices
        // within fragment storage are not defined.
        for (int t = 0; t < c[i][j].num_elements; t++) c[i][j].x[t] *= alpha;

        float *tile_ptr = shmem_warp_tile_ptr + i * SHMEM_STRIDE * K + j * N;

        wmma::store_matrix_sync(tile_ptr, c[i][j], SHMEM_STRIDE, C_LAYOUT);
      }
    }

    __syncthreads();

    // Now that shared memory contains all the D tiles, stream them to global
    // memory.
    float *dst_gmem_warp_stream_ptr = &D[gmem_idx];

#pragma unroll
    for (int i = 0; i < K; i++) {
      *((int4 *)(dst_gmem_warp_stream_ptr + GLOBAL_MEM_STRIDE * i) + laneId) =
          *((int4 *)(shmem_warp_stream_ptr + SHMEM_STRIDE * i) + laneId);
    }

    __syncthreads();
  }
}

// Performs an MxNxK GEMM (C=alpha*A*B + beta*C) assuming:
//  1) Matrices are packed in memory.
//  2) M, N and K are multiples of 16.
//  3) Neither A nor B are transposed.
// Note: This is a less performant version of the compute_gemm kernel. It is
// designed for
//       demonstration purposes only to show the CUDA WMMA API use without
//       relying on availability of the shared memory.
__global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld,
                                 int n_ld, int k_ld, float alpha, float beta) {
  // Leading dimensions. Packed with no transpositions.
  int lda = k_ld;
  int ldb = k_ld;
  int ldc = n_ld;

  // Tile using a 2D grid
  int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
  int warpN = (blockIdx.y * blockDim.y + threadIdx.y);

  // Declare the fragments
  wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major>
      a_frag;
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major>
      b_frag;
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;

  wmma::fill_fragment(acc_frag, 0.0f);

  // Loop over k
  for (int i = 0; i < k_ld; i += WMMA_K) {
    int aCol = i;
    int aRow = warpM * WMMA_M;
    int bCol = warpN * N;
    int bRow = i;

    // Bounds checking
    if (aRow < m_ld && aCol < k_ld && bRow < k_ld && bCol < n_ld) {
      // Load the inputs
      wmma::load_matrix_sync(a_frag, a + aCol + aRow * lda, lda);
      wmma::load_matrix_sync(b_frag, b + bRow + bCol * ldb, ldb);

      // Perform the matrix multiplication
      wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
    }
  }

  // Load in the current value of c, scale it by beta, and add this our result
  // scaled by alpha
  int cCol = warpN * WMMA_N;
  int cRow = warpM * WMMA_M;

  if (cRow < m_ld && cCol < n_ld) {
    wmma::load_matrix_sync(c_frag, c + cCol + cRow * ldc, ldc,
                           wmma::mem_row_major);

    for (int i = 0; i < c_frag.num_elements; i++) {
      c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i];
    }

    // Store the output
    wmma::store_matrix_sync(d + cCol + cRow * ldc, c_frag, ldc,
                            wmma::mem_row_major);
  }
}

__host__ void matMultiplyOnHost(half *A, half *B, float *C, float alpha,
                                float beta, int numARows, int numAColumns,
                                int numBRows, int numBColumns, int numCRows,
                                int numCColumns) {
  for (int i = 0; i < numCRows; i++) {
    for (int j = 0; j < numCColumns; j++) {
      float temp = 0.0;

      for (int k = 0; k < numAColumns; k++) {
        temp += (float)A[i * numAColumns + k] * (float)B[j * numBRows + k];
      }

      C[i * numCColumns + j] = temp * alpha + beta * C[i * numCColumns + j];
    }
  }
}

int main(int argc, char **argv) {
  printf("Initializing...\n");

  int dev = findCudaDevice(argc, (const char **)argv);

  cudaDeviceProp deviceProp;
  checkCudaErrors(cudaGetDeviceProperties(&deviceProp, dev));

  // Tensor cores require a GPU of Volta (SM7X) architecture or higher.
  if (deviceProp.major < 7) {
    printf(
        "cudaTensorCoreGemm requires SM 7.0 or higher to use Tensor "
        "Cores.  Exiting...\n");
    exit(EXIT_WAIVED);
  }

  printf("M: %d (%d x %d)\n", M_GLOBAL, M, M_TILES);
  printf("N: %d (%d x %d)\n", N_GLOBAL, N, N_TILES);
  printf("K: %d (%d x %d)\n", K_GLOBAL, K, K_TILES);

  half *A_h = NULL;
  half *B_h = NULL;
  float *C_h = NULL;
#if CPU_DEBUG
  float *result_hD = NULL;
  float *result_host = NULL;
#endif

  A_h = (half *)malloc(sizeof(half) * M_GLOBAL * K_GLOBAL);
  B_h = (half *)malloc(sizeof(half) * K_GLOBAL * N_GLOBAL);
  C_h = (float *)malloc(sizeof(float) * M_GLOBAL * N_GLOBAL);
#if CPU_DEBUG
  result_hD = (float *)malloc(sizeof(float) * M_GLOBAL * N_GLOBAL);
  result_host = (float *)malloc(sizeof(float) * M_GLOBAL * N_GLOBAL);
#endif

  half *A = NULL;
  half *B = NULL;
  float *C = NULL;
  float *D = NULL;

  checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&A),
                             sizeof(half) * M_GLOBAL * K_GLOBAL));
  checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&B),
                             sizeof(half) * N_GLOBAL * K_GLOBAL));
  checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&C),
                             sizeof(float) * M_GLOBAL * N_GLOBAL));
  checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&D),
                             sizeof(float) * M_GLOBAL * N_GLOBAL));

  assert(((unsigned long long)A) % 128 == 0);
  assert(((unsigned long long)B) % 128 == 0);
  assert(((unsigned long long)C) % 128 == 0);
  assert(((unsigned long long)D) % 128 == 0);

  init_host_matrices(A_h, B_h, C_h);

  printf("Preparing data for GPU...\n");

  checkCudaErrors(cudaMemcpy(A, A_h, sizeof(half) * M_GLOBAL * K_GLOBAL,
                             cudaMemcpyHostToDevice));
  checkCudaErrors(cudaMemcpy(B, B_h, sizeof(half) * N_GLOBAL * K_GLOBAL,
                             cudaMemcpyHostToDevice));
  checkCudaErrors(cudaMemcpy(C, C_h, sizeof(float) * M_GLOBAL * N_GLOBAL,
                             cudaMemcpyHostToDevice));
  checkCudaErrors(cudaMemset(D, 0, sizeof(float) * M_GLOBAL * N_GLOBAL));

  enum {
    // Compute the right amount of shared memory to request.
    // We need shared memory to hold per-CTA C and D matrix tiles, and to cache
    // per-CTA chunks
    // of the A and B matrices. Therefore, the right amount to request is the
    // maximum of those
    // two numbers.
    SHMEM_SZ = MAX(
        sizeof(half) * (BLOCK_COL_TILES * M) * (CHUNK_K * K + SKEW_HALF) * 2,
        M * (BLOCK_ROW_WARPS * WARP_ROW_TILES) * N *
            (BLOCK_COL_WARPS * WARP_COL_TILES) * sizeof(float))
  };

  printf("Required shared memory size: %lu Kb\n", SHMEM_SZ / 1024UL);

  const float alpha = 1.1f;
  const float beta = 1.2f;

  cudaEvent_t start, stop;

  checkCudaErrors(cudaEventCreate(&start));
  checkCudaErrors(cudaEventCreate(&stop));
  checkCudaErrors(cudaEventRecord(start));

  // If enough shared memory available on the GPU use high performant kernel
  if (deviceProp.sharedMemPerMultiprocessor >= SHMEM_SZ) {
    printf("Computing... using high performance kernel compute_gemm \n");

    checkCudaErrors(cudaFuncSetAttribute(
        compute_gemm, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ));
    checkKernelErrors(
        (compute_gemm<<<deviceProp.multiProcessorCount, THREADS_PER_BLOCK,
                        SHMEM_SZ>>>(A, B, C, D, alpha, beta)));
#if CPU_DEBUG
    checkCudaErrors(cudaMemcpy(result_hD, D,
                               sizeof(float) * M_GLOBAL * N_GLOBAL,
                               cudaMemcpyDeviceToHost));
#endif
  } else {
    dim3 gridDim;
    dim3 blockDim;

    // blockDim.x must be a multple of warpSize
    // 128x4 means we have 16 warps and a block computes a 64x64 output tile
    blockDim.x = 128;
    blockDim.y = 4;

    gridDim.x = (M_GLOBAL + (WMMA_M * blockDim.x / 32 - 1)) /
                (WMMA_M * blockDim.x / 32);
    gridDim.y = (N_GLOBAL + WMMA_N * blockDim.y - 1) / (WMMA_N * blockDim.y);

    printf("Computing... using simple_wmma_gemm kernel\n");
    simple_wmma_gemm<<<gridDim, blockDim>>>(A, B, C, D, M_GLOBAL, N_GLOBAL,
                                            K_GLOBAL, alpha, beta);
#if CPU_DEBUG
    checkCudaErrors(cudaMemcpy(result_hD, D,
                               sizeof(float) * M_GLOBAL * N_GLOBAL,
                               cudaMemcpyDeviceToHost));
#endif
  }

  checkCudaErrors(cudaEventRecord(stop));
  checkCudaErrors(cudaEventSynchronize(stop));

#if CPU_DEBUG
  printf("Verifying correctness of the computations...\n");

  memcpy(result_host, C_h, sizeof(float) * M_GLOBAL * N_GLOBAL);

  matMultiplyOnHost(A_h, B_h, result_host, alpha, beta, M_GLOBAL, K_GLOBAL,
                    K_GLOBAL, N_GLOBAL, M_GLOBAL, N_GLOBAL);

  for (int i = 0; i < N_GLOBAL * M_GLOBAL; i++) {
    if (fabs(result_hD[i] - result_host[i]) > 0.1f)
      printf("mismatch i=%d result_hD=%f result_host=%f\n", i, result_hD[i],
             result_host[i]);
  }
  free(result_hD);
  free(result_host);
#endif

  float milliseconds = 0;

  checkCudaErrors(cudaEventElapsedTime(&milliseconds, start, stop));

  printf("Time: %f ms\n", milliseconds);
  printf("TFLOPS: %.2f\n", static_cast<double>((static_cast<double>(M_GLOBAL) *
                                                N_GLOBAL * K_GLOBAL * 2) /
                                               (milliseconds / 1000.)) /
                               1e12);

  free(A_h);
  free(B_h);
  free(C_h);
  checkCudaErrors(cudaFree(reinterpret_cast<void *>(A)));
  checkCudaErrors(cudaFree(reinterpret_cast<void *>(B)));
  checkCudaErrors(cudaFree(reinterpret_cast<void *>(C)));
  checkCudaErrors(cudaFree(reinterpret_cast<void *>(D)));

  return 0;
}

TensorCore

第一代——Volta

  • 混合精度矩阵乘法 - FP16 和 FP32
  • 可达到高达 12 倍的峰值 TFLOPS 性能

该技术结合了 FP16(半精度浮点数)和 FP32(单精度浮点数)的优势,可以在保持精度的同时,大幅提升矩阵乘法运算的性能。

主要特点如下:

  1. 使用 FP16 进行大部分的数据存储和计算,利用 FP16 的高带宽和低功耗优势。
  2. 关键的累加运算使用 FP32,保证最终结果的精度。
  3. 充分利用 GPU 硬件加速单元,如 Tensor Cores,实现矩阵乘法的高性能计算。
  4. 可以达到高达 12 倍于传统 FP32 矩阵乘法的峰值 TFLOPS 性能。

这种混合精度的矩阵乘法技术广泛应用于深度学习、图形渲染、科学计算等领域,是一种提升计算性能的有效方法。它充分利用了 GPU 硬件的并行计算能力,同时保证了所需的数值精度。

NVIDIA-Tensor-Core

TensorCore硬件单元

下图是1/4个SM硬件单元:

NVIDIA V100 Tensor Cores 是可编程的矩阵乘累加单元,可为训练和推理应用程序提供高达 125 Tensor TFLOPS 的性能。V100 GPU 包含 640 个 Tensor Cores,每个 SM 有 8 个。Tensor Cores 及其相关数据通路是专门设计的,可显著提高浮点计算吞吐量,同时只需要很小的面积和功耗。广泛使用时钟门控技术来最大化功耗节省。

混合精度矩阵运算 - 4x4 矩阵

每个 Tensor Core 提供一个 4x4x4 矩阵处理阵列,执行 D = A * B + C 运算,其中 A、B、C 和 D 都是 4x4 矩阵(图 1)。矩阵乘法输入 A 和 B 使用 FP16 格式,而累加矩阵 C 和 D 可以是 FP16 或 FP32 格式。

这种技术利用 FP16(半精度浮点数)和 FP32(单精度浮点数)的优势,实现高效的 4x4 矩阵运算。其主要特点如下:

  1. 输入矩阵采用 FP16 格式存储,利用 FP16 的高带宽和低功耗优势。
  2. 计算过程中使用 FP32 进行关键的累加运算,以保证最终结果的精度。
  3. 充分利用 GPU 硬件的加速单元,如 Tensor Cores,来实现高性能的矩阵乘法计算。
  4. 与纯 FP32 矩阵运算相比,可以获得 2-4 倍的性能提升。

这种混合精度矩阵运算技术广泛应用于以下场景:

  • 深度学习模型的训练和推理
  • 图形渲染和几何变换
  • 科学计算和工程模拟

它能够在保持所需精度的同时,显著提升矩阵运算的计算性能。这种方法充分利用了现代 GPU 硬件的并行计算能力,是一种非常有效的性能优化技术。

每个 Tensor Core 每个时钟周期可执行 64 个混合精度的浮点 FMA(Fused Multiply-Add)运算,其中使用 FP16 格式进行输入乘法,并产生全精度的乘积,然后使用 FP32 格式进行累加(下图)。一个 SM 中的 8 个 Tensor Cores 总共可以每个时钟周期执行 1024 个浮点运算。

与使用标准 FP32 运算的 Pascal GP100 相比,这种方法为深度学习应用程序每个 SM 带来了 8 倍的吞吐量提升。与 Pascal P100 GPU 相比,Volta V100 GPU 的总体吞吐量提升达到了 12 倍。Tensor Cores 使用 FP16 输入数据,并进行 FP32 累加。FP16 乘法得到的全精度结果将与其他点积结果在 4x4x4 矩阵乘法中进行 FP32 累加(上图)。

在程序执行期间,多个 Tensor Cores 会被一个完整的 warp 并发使用。Warp 内的线程提供了一个更大的 16x16x16 矩阵运算,供 Tensor Cores 处理。CUDA 在 CUDA C++ WMMA API 中公开了这些操作作为 warp 级矩阵操作。这些 C++ 接口提供了专门的矩阵加载、矩阵乘累加和矩阵存储操作,以高效地在 CUDA C++ 程序中使用 Tensor Cores。

WMMA

Tensor Cores 支持一种称为 warp 矩阵乘法累加(wmma)的操作,提供了针对 FP16 (hmma) 和整数 (imma) 矩阵乘法的优化路径。

WMMA API 包含在 mma.h 头文件中。完整的命名空间是 nvcuda::wmma::*,但在代码中保留 wmma 是很有用的,所以我们只使用 nvcuda 命名空间。

#include <mma.h>
using namespace nvcuda;

完整的 GEMM 规范允许该算法对 a 或 b 进行转置操作,并且数据步长可以大于矩阵中的步长。为了简单起见,假设 a 和 b 都没有被转置,并且内存和矩阵的主要维度是相同的。

要采用的策略是让单个 warp 负责输出矩阵的单个 16×16 部分。通过使用 2D 网格和线程块,可以有效地将 warp 平铺在 2D 输出矩阵上。

// The only dimensions currently supported by WMMA
const int WMMA_M = 16;
const int WMMA_N = 16;
const int WMMA_K = 16;
 
__global__ void wmma_example(half *a, half *b, float *c, 
                             int M, int N, int K, 
                             float alpha, float beta) 
{
 
    // Leading dimensions. Packed with no transpositions.
    int lda = M;
    int ldb = K;
    int ldc = M;
     
    // Tile using a 2D grid
    int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
    int warpN = (blockIdx.y * blockDim.y + threadIdx.y);

在执行 MMA 操作之前,操作数矩阵必须以寄存器的形式表示在 GPU 上。由于 MMA 是一个 warp 级别的操作,这些寄存器会分布在 warp 中的各个线程中,每个线程持有整个矩阵的一个片段。

TensorCore的编程有四个步骤:

1. 创建Fragment

在做运算之前,需要把数据从内存里读取到tensorCore里,所以需要先加载数据。

在 CUDA 中,fragment是一个带有template参数的template类型,这些参数描述了以下内容:

  • fragment包含的矩阵(A、B 或累加器)
  • 整个 WMMA 操作的形状
  • 数据类型
  • 对于 A 和 B 矩阵,数据是行主序还是列主序

最后一个参数可用于对 A 或 B 矩阵进行转置。这个示例没有进行转置,所以两个矩阵都是列主序,这是 GEMM 的标准。

// Declare the fragments
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;

这段代码中声明了四个 WMMA 片段,这些片段表示了 GEMM 操作中涉及的三个矩阵:输入矩阵 A 和 B、以及输出矩阵 C。它们使用 WMMA API 中的特殊格式来利用 Tensor Cores 进行高效的矩阵乘法计算。

2. 初始化Fragments

初始化步骤的最后一部分是将累加器片段填充为零。

wmma::fill_fragment(acc_frag, 0.0f);

GEMM 的策略是让每个 warp 计算输出矩阵的一个 tile。为此,需要循环遍历 A 矩阵的行和 B 矩阵的列。这沿着两个矩阵的 K 维度,产生一个 M x N 的输出 tile。

加载矩阵的函数从内存(本例中为全局内存,尽管也可以是任何其他内存空间)读取数据,并将其放入一个片段中。

加载函数的第三个参数是矩阵在内存中的主维度。由于要加载的 16x16 tile b 在内存中是不连续的,所以该函数需要知道连续列或行之间的步长,如果这些是行主序片段的话。

MMA 调用是就地累积的,所以第一个和最后一个参数都是之前初始化为零的累加器片段。

// Loop over the K-dimension
for (int i = 0; i < K; i += WMMA_K) {
    int aRow = warpM * WMMA_M;
    int aCol = i;
    int bRow = i;
    int bCol = warpN * WMMA_N;
     
    // Bounds checking
    if (aRow < M && aCol < K && bRow < K && bCol < N) {
        // Load the inputs
        wmma::load_matrix_sync(a_frag, a + aRow + aCol * lda, lda);
        wmma::load_matrix_sync(b_frag, b + bRow + bCol * ldb, ldb);
 
    }
}
  1. 该代码实现了一个循环,遍历矩阵 A 和矩阵 B 的 K 维度。循环步长为 WMMA_K。

  2. 在每次迭代中,计算了 A 矩阵的行索引 aRow 和 B 矩阵的列索引 bCol。这些索引是根据当前 warp 的位置 (warpM, warpN) 和 WMMA 的尺寸 (WMMA_M, WMMA_N) 计算出来的。

  3. 接下来进行边界检查,确保索引在矩阵的有效范围内。

  4. 如果索引在有效范围内,则使用 wmma::load_matrix_sync() 函数从内存中加载 A 矩阵和 B 矩阵的对应数据块到 a_frag 和 b_frag 中。这里需要提供内存地址偏移和每行的主维度长度 (lda, ldb)。

  5. 最后,使用 wmma::mma_sync() 函数执行矩阵乘法运算,将结果累积到 acc_frag 中。

这段代码实现了通过 warp 级并行计算来完成 GEMM (General Matrix Multiplication) 操作的策略。它遍历 A 和 B 矩阵的 K 维度,加载相应的数据块,并执行矩阵乘法运算,最终累积到输出矩阵中。

acc_frag 现在保存了基于 A 和 B 矩阵乘法计算得到的这个 warp 的输出块的结果。完整的 GEMM 规范允许对这个结果进行缩放操作,并累加到现有的矩阵上。

"fragment"在这段代码中指的是 CUDA 中的 WMMA (Warp Matrix Multiply-Accumulate) 操作的一部分。具体来说:

  1. a_frag, b_frag 和 acc_frag 是在 WMMA 计算过程中使用的数据结构,用于存储和操作矩阵数据。

  2. "fragment"表示从矩阵中取出的一个小块或子矩阵,这些子矩阵会被并行地加载和计算。

  3. WMMA 操作会将大型的矩阵乘法问题分解为多个小的矩阵乘法子问题,每个子问题都由一个 warp 来并行计算。

  4. a_frag、b_frag 和 acc_frag 就是这些子矩阵的数据结构,用于在 warp 内部存储和操作这些小的矩阵块。

  5. 通过对这些 fragment 进行加载、计算和累加,最终可以完成整个大型矩阵的乘法运算。

所以 fragment 在这里代表的是 WMMA 操作中使用的小型矩阵块,是并行计算过程的基本单元。它们的尺寸和数量都是根据 WMMA 的配置参数 (WMMA_M, WMMA_N, WMMA_K) 来决定的。

执行这种缩放操作的一种方法是对分片(fragment)进行逐元素操作。尽管线程和矩阵坐标的对应关系并未定义,但逐元素操作并不需要知道这一对应关系,因此仍然可以在分片上执行。

因此,只要两个分片具有相同的模板参数,就可以对它们执行缩放操作或将一个分片的内容累加到另一个分片上。如果分片的模板参数不同,结果是未定义的。利用这个特性,可以先加载 C 矩阵中现有的数据,然后使用正确的缩放因子,将当前计算的结果累加到该数据上。

3. 执行矩阵操作

        // Perform the matrix multiplication
        wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
// Load in current value of c, scale by beta, and add to result scaled by alpha
int cRow = warpM * WMMA_M;
int cCol = warpN * WMMA_N;
 
if (cRow < M && cCol < N) {
    wmma::load_matrix_sync(c_frag, c + cRow + cCol * ldc, ldc, wmma::mem_col_major);
     
    for(int i=0; i < c_frag.num_elements; i++) {
        c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i];
    }
  1. 首先计算 C 矩阵当前块的行索引 cRow 和列索引 cCol,同样是根据当前 warp 的位置和 WMMA 的尺寸计算得到。

  2. 进行边界检查,确保索引在 C 矩阵的有效范围内。

  3. 使用 wmma::load_matrix_sync() 函数从内存中加载 C 矩阵当前块的数据到 c_frag 分片中。这里指定了数据的存储格式为列主序。

  4. 然后,遍历 c_frag 分片中的所有元素,对每个元素执行以下操作:

    • 将当前结果 acc_frag 乘以缩放因子 alpha
    • 将当前 C 矩阵的值乘以缩放因子 beta
    • 将以上两者相加的结果存回 c_frag

这段代码实现了将计算结果 acc_frag 与 C 矩阵当前块的值相结合的操作,并根据指定的缩放因子 alpha 和 beta 进行了相应的缩放。最终,更新后的 C 矩阵块的值将被写回内存。

4. 存储结果

最后,将数据存储到内存中。同样地,目标指针可以是 GPU 可见的任何内存空间,并且必须指定内存中的主维度长度。还可以指定输出是以行主序还是列主序的方式写入。

        // Store the output
        wmma::store_matrix_sync(c + cRow + cCol * ldc, c_frag, ldc, wmma::mem_col_major);
    }
}

其他sample也使用了该技术,只是支持的数据类型不同:

Kernel解读

__global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld,
                                 int n_ld, int k_ld, float alpha, float beta) {
  // Leading dimensions. Packed with no transpositions.
  int lda = k_ld;
  int ldb = k_ld;
  int ldc = n_ld;

  // Tile using a 2D grid
  int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
  int warpN = (blockIdx.y * blockDim.y + threadIdx.y);

  // Declare the fragments
  wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major>
      a_frag;
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major>
      b_frag;
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;

  wmma::fill_fragment(acc_frag, 0.0f);

  // Loop over k
  for (int i = 0; i < k_ld; i += WMMA_K) {
    int aCol = i;
    int aRow = warpM * WMMA_M;
    int bCol = warpN * N;
    int bRow = i;

    // Bounds checking
    if (aRow < m_ld && aCol < k_ld && bRow < k_ld && bCol < n_ld) {
      // Load the inputs
      wmma::load_matrix_sync(a_frag, a + aCol + aRow * lda, lda);
      wmma::load_matrix_sync(b_frag, b + bRow + bCol * ldb, ldb);

      // Perform the matrix multiplication
      wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
    }
  }

  // Load in the current value of c, scale it by beta, and add this our result
  // scaled by alpha
  int cCol = warpN * WMMA_N;
  int cRow = warpM * WMMA_M;

  if (cRow < m_ld && cCol < n_ld) {
    wmma::load_matrix_sync(c_frag, c + cCol + cRow * ldc, ldc,
                           wmma::mem_row_major);

    for (int i = 0; i < c_frag.num_elements; i++) {
      c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i];
    }

    // Store the output
    wmma::store_matrix_sync(d + cCol + cRow * ldc, c_frag, ldc,
                            wmma::mem_row_major);
  }
}

这段代码实现了一个基于 CUDA WMMA (Warp Matrix Multiply-Accumulate) API 的矩阵乘法内核函数。它执行一个 M x N x K 的矩阵乘法运算 (C = alpha * A * B + beta * C)。具体逻辑如下:

  1. 假设:

    • 矩阵 A、B 和 C 都是内存中连续存储的。
    • M、N 和 K 都是 16 的倍数。
    • A 和 B 矩阵都没有进行转置。
  2. 计算 leading dimensions (lda、ldb、ldc)。这些值表示矩阵在内存中的行间距。

  3. 根据当前线程在 2D 网格中的位置,计算当前 warp 处理的矩阵块的行索引 warpM 和列索引 warpN。

  4. 声明 WMMA 操作所需的各种 fragment 数据结构:

    • a_fragb_frag: 用于存储 A 和 B 矩阵的子块
    • acc_frag: 用于存储和累积部分矩阵乘法结果
    • c_frag: 用于存储 C 矩阵的当前子块
  5. 将 acc_frag 初始化为 0。

  6. 循环遍历 K 维度,对每个 K 块执行以下操作:

    • 计算当前 A 和 B 子块的行列索引 (aRow、aCol、bRow、bCol)
    • 进行边界检查,确保索引在矩阵范围内
    • 使用 wmma::load_matrix_sync() 从内存中加载 A 和 B 子块到 a_frag 和 b_frag
    • 使用 wmma::mma_sync() 执行矩阵乘法,结果累加到 acc_frag
  7. 计算当前 C 矩阵子块的行列索引 (cRow、cCol)。

  8. 再次进行边界检查,确保 C 子块索引在矩阵范围内。

  9. 使用 wmma::load_matrix_sync() 从内存中加载当前 C 子块到 c_frag

  10. 遍历 c_frag 中的每个元素,将计算结果 (acc_frag) 乘以 alpha,将 C 子块的值乘以 beta,并将两者相加的结果存回 c_frag

  11. 最后,使用 wmma::store_matrix_sync() 将更新后的 c_frag 子块写回内存。

这个内核函数利用 CUDA WMMA API 以高效的方式执行了一个矩阵乘法运算,并将结果更新到输出矩阵 C 中。它是一个示例代码,用于展示如何在没有共享内存的情况下使用 WMMA 进行计算。

compute_gemm内核函数的实现逻辑与simple_wmma_gemm类似,区别是它使用了shared memory,内存分配方式不同。

运行结果:

Initializing...
GPU Device 0: "Ada" with compute capability 8.9

M: 4096 (16 x 256)
N: 4096 (16 x 256)
K: 4096 (16 x 256)
Preparing data for GPU...
Required shared memory size: 64 Kb
Computing... using high performance kernel compute_gemm
Time: 165.487610 ms
TFLOPS: 0.83

  • 17
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值