The 3_CUDA_Features directory of the CUDA samples contains samples that demonstrate various CUDA features. The cudaTensorCoreGemm sample shows how to perform matrix multiplication with the WMMA (Warp Matrix Multiply-Accumulate) API.
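Before walking through the full listing, here is a minimal sketch of what the WMMA API looks like in isolation: one warp loads a 16x16 tile of A, a 16x16 tile of B, and a 16x16 accumulator tile, then issues a single Tensor Core multiply-accumulate. The kernel name and the assumption that each pointer addresses one densely packed 16x16 tile are illustrative only, not part of the sample; compiling it requires compute capability 7.0 (sm_70) or newer.

#include <mma.h>

using namespace nvcuda;

// One warp computes D = A * B + C for a single 16x16x16 tile.
// Assumes a, b, c, d each point to one packed 16x16 tile (hypothetical setup).
__global__ void wmma_tile_example(const half *a, const half *b, const float *c,
                                  float *d) {
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag;

  wmma::load_matrix_sync(a_frag, a, 16);  // 16 = leading dimension in elements
  wmma::load_matrix_sync(b_frag, b, 16);
  wmma::load_matrix_sync(acc_frag, c, 16, wmma::mem_row_major);
  wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
  wmma::store_matrix_sync(d, acc_frag, 16, wmma::mem_row_major);
}

The full source of the sample follows.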
#include <assert.h>
#include <cuda.h>
#include <mma.h>
#include <stdio.h>
// helper functions and utilities to work with CUDA
#include <helper_cuda.h>
#include <helper_functions.h>
// Externally configurable parameters.
#ifndef CPU_DEBUG
// Set this to 1 to verify the correctness of the GPU-computed matrix.
#define CPU_DEBUG 0
#endif
#ifndef SHARED_MEMORY_LIMIT_64K
// Set this to 0 to use more than 64 KB of shared memory to cache data, which
// improves the performance of the computation on the GPU.
// Note that this requires a GPU with more than 64 KB of shared memory
// available per multiprocessor.
#define SHARED_MEMORY_LIMIT_64K 1
#endif
// GPU configuration.
#define WARP_SIZE 32
// MMA matrix tile dimensions.
#define M 16
#define N 16
#define K 16
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16
// GEMM configuration.
#define M_TILES 256
#define N_TILES 256
#define K_TILES 256
#define M_GLOBAL (M * M_TILES)
#define N_GLOBAL (N * N_TILES)
#define K_GLOBAL (K * K_TILES)
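// With the defaults above this is a 4096 x 4096 x 4096 GEMM:
// M_GLOBAL = N_GLOBAL = K_GLOBAL = 16 * 256 = 4096.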
#define C_LAYOUT wmma::mem_row_major
// Implementation constants.
#define WARPS_PER_BLOCK 8
#define THREADS_PER_BLOCK (WARP_SIZE * WARPS_PER_BLOCK)
#if SHARED_MEMORY_LIMIT_64K
// With only 64 KB of shared memory available, we could fit two 8-tile chunks
// of the A and B matrix data, at 16 * 16 * 8 * 8 * 2 = 32 KB each
// (i.e. two 8x8 arrays of tiles of 16x16 half-typed elements per CTA).
// But that leaves no room for the 8 KB of total skew overhead, without which
// performance would be severely impacted. So we halve the chunk size instead,
// i.e. the amount of A and B matrix data we cache in shared memory.
// Accordingly, this doubles the number of outer iterations across the global
// K dimension, which only slightly impacts performance.
#define CHUNK_K 4
#else
#define CHUNK_K 8
#endif
#define CHUNK_LINE_BYTES (CHUNK_K * K * sizeof(half))
#define WARP_COPY_BYTES (WARP_SIZE * sizeof(int4))
#define CHUNK_COPY_LINES_PER_WARP (WARP_COPY_BYTES / CHUNK_LINE_BYTES)
#define CHUNK_COPY_LINE_LANES (WARP_SIZE / CHUNK_COPY_LINES_PER_WARP)
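// Worked example with CHUNK_K == 4: one line of a chunk is 4 * 16 * 2 = 128
// bytes, and a warp moves 32 * 16 = 512 bytes per int4 copy, so each warp
// copies 512 / 128 = 4 lines at a time, with 32 / 4 = 8 lanes per line.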
#define BLOCK_ROW_WARPS 2
#define BLOCK_COL_WARPS 4
#define WARP_ROW_TILES 4
#define WARP_COL_TILES 2
#define BLOCK_ROW_TILES (WARP_ROW_TILES * BLOCK_ROW_WARPS)
#define BLOCK_COL_TILES (WARP_COL_TILES * BLOCK_COL_WARPS)
#define GLOBAL_MEM_STRIDE N_GLOBAL
#define SHMEM_STRIDE (N * BLOCK_ROW_TILES)
#define SHMEM_OFFSET (N * WARP_ROW_TILES)
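// With these values each CTA owns an 8 x 8 grid of 16x16 tiles, i.e. a
// 128 x 128 block of the output. SHMEM_STRIDE is 16 * 8 = 128 floats and
// SHMEM_OFFSET is 16 * 4 = 64 floats.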
// The macro below is used to shift rows of the A matrix and columns of the B matrix
// in shared memory to minimize possible bank conflicts.
// Before performing the nvcuda::wmma::mma_sync operation, the warp must load the matrix
// data using the nvcuda::wmma::load_matrix_sync operation. Although the memory access pattern
// is not specified for that function, each lane in the warp can read one or multiple matrix
// elements from different matrix rows or columns.
// For shared memory, such access can result in bank conflicts if different rows / columns
// of the matrix map to the same bank. By shifting each row and column by a few bytes, we
// make sure that they map to different banks, thus reducing the number of possible bank
// conflicts.
// A shift of 16 two-byte "half" elements is chosen as the minimum possible
// because each row and column must stay 256-bit aligned, as required by
// nvcuda::wmma::load_matrix_sync.
#define SKEW_HALF 16
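// Worked example with CHUNK_K == 4: the padded row pitch is
// (4 * 16 + 16) * sizeof(half) = 160 bytes. Shared memory has 32 four-byte
// banks (128 bytes per sweep), so consecutive rows start 160 % 128 = 32 bytes
// (8 banks) apart rather than all beginning in the same bank.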
#define checkKernelErrors(expr)                             \
  do {                                                      \
    expr;                                                   \
                                                            \
    cudaError_t __err = cudaGetLastError();                 \
    if (__err != cudaSuccess) {                             \
      printf("Line %d: '%s' failed: %s\n", __LINE__, #expr, \
             cudaGetErrorString(__err));                    \
      abort();                                              \
    }                                                       \
  } while (0)
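// Typical usage wraps the kernel launch in an extra pair of parentheses so
// the commas inside <<<...>>> are not taken as separate macro arguments, e.g.
// (grid and shmem_sz stand in for whatever launch configuration is used):
//   checkKernelErrors((compute_gemm<<<grid, THREADS_PER_BLOCK, shmem_sz>>>(
//       A, B, C, D, alpha, beta)));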
using namespace nvcuda;
__host__ void init_host_matrices(half *a, half *b, float *c) {
  for (int i = 0; i < M_GLOBAL; i++) {
    for (int j = 0; j < K_GLOBAL; j++) {
      a[i * K_GLOBAL + j] = (half)(rand() % 3);
    }
  }

  for (int i = 0; i < N_GLOBAL; i++) {
    for (int j = 0; j < K_GLOBAL; j++) {
      b[i * K_GLOBAL + j] = (half)(rand() % 3);
    }
  }

  for (int t = 0; t < M_GLOBAL * N_GLOBAL; t++) {
    c[t] = static_cast<float>(rand() % 3);
  }
}
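// Note that b is filled as an N_GLOBAL x K_GLOBAL array with K as the fastest
// dimension; viewed as a K x N matrix this is column-major storage, which
// matches the wmma::col_major layout the full sample uses for its B fragments.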
__global__ void compute_gemm(const half *A, const half *B, const float *C,
                             float *D, float alpha, float beta) {
  extern __shared__ half shmem[][CHUNK_K * K + SKEW_HALF];

  // Warp and lane identification.
  const unsigned int warpId = threadIdx.x / WARP_SIZE;
  const unsigned int laneId = threadIdx.x % WARP_SIZE;

  // Offset in shared memory from which the B matrix is stored.
  const size_t shmem_idx_b_off = BLOCK_COL_TILES * M;

  // This pointer is used to access the C and D matrix tiles this warp computes.
  float *shmem_warp_tile_ptr = (float *)&shmem[0][0] +
                               (warpId / 2) * SHMEM_STRIDE * K * 2 +
                               (warpId % 2) * SHMEM_OFFSET;

  // This pointer is used to stream the block-wide tile of the C and D
  // matrices to and from shared memory.
  float *shmem_warp_stream_ptr =
      (float *)&shmem[0][0] + warpId * SHMEM_STRIDE * K;

  // Adjust the beta scaler, as it'll be multiplied by alpha at the end of
  // each tile computation. Technically this is not generally correct (it may
  // result in a loss of precision). Zero still needs to be handled specially,
  // though.
  beta /= alpha;
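  // The division above relies on the identity
  //   alpha * (A x B) + beta * C == alpha * ((A x B) + (beta / alpha) * C),
  // so the C tiles can be folded into the accumulators and the result is
  // scaled by alpha just once per tile.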
  // Each CTA slides along the 128 x 128 tiles from the top left corner of the
  // matrix to the right and down, and selects the next tile to compute. Once
  // there's no such tile, all warps in this CTA exit.
  for (unsigned int block_pos = blockIdx.x;; block_pos += gridDim.x) {
    const unsigned int block_tile_i =
        ((block_pos * BLOCK_ROW_TILES) / N_TILES) * (BLOCK_COL_TILES);
    const unsigned int block_tile_j = (block_pos * BLOCK_COL_TILES) % N_TILES;
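    // With BLOCK_ROW_TILES == BLOCK_COL_TILES == 8, each increment of
    // block_pos moves 8 tiles (128 columns) to the right along N, wrapping
    // down to the next band of 8 tile rows (128 rows) at the matrix edge.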
    // Stop when there are no more D matrix tiles to compute in this CTA.
    if (block_tile_i >= M_TILES) {
      break;
    }
    // This warp's pointer to the C matrix data to copy from global memory to
    // shared memory.
    const size_t gmem_idx =
        (block_tile_i + warpId) * M * GLOBAL_MEM_STRIDE + block_tile_j * N;
    const float *src_gmem_warp_stream_ptr = &C[gmem_idx];

    // Stream multiple C tiles to shared memory.
#pragma unroll
    for (int i = 0; i < K; i++) {
      typedef int4 copy_t;

      *((copy_t *)(shmem_warp_stream_ptr + SHMEM_STRIDE * i) + laneId) =
          *((copy_t *)(src_gmem_warp_stream_ptr + GLOBAL_MEM_STRIDE * i) +
            laneId);