使用二条PTX MMA m16n8k16指令实现一个16x16x16的GEMM,并跟wmma对比差异

使用二条PTX MMA m16n8k16指令实现一个16x16x16的GEMM,并跟wmma对比差异

本文演示,如何使用二条PTX MMA m16n8k16指令实现一个16x16x16的GEMM,并跟wmma对比差异

1.运行numpy的矩阵乘

tee numpy_gemm.py<<-'EOF'
def print_data(data):
    r,c=data.shape
    for i in range(r):
        for j in range(c):
            print(f"{data[i][j]:8.3f}",end=",")
        print("\n",end="")
    print("\n",end="")

import numpy as np
M=16
N=16
K=16
input_a=np.arange(M*K,dtype=np.float32).reshape(M,K)*0.01
input_a=input_a.astype(np.float16)
input_b=np.arange(K*N,dtype=np.float32).reshape(K,N)*0.01
input_b=input_b.astype(np.float16)
output_d=np.dot(input_a,input_b)
print_data(output_d)
EOF
python numpy_gemm.py

输出

   1.984,   1.996,   2.008,   2.020,   2.031,   2.043,   2.057,   2.068,   2.080,   2.092,   2.104,   2.115,   2.127,   2.141,   2.152,   2.164,
   5.055,   5.094,   5.133,   5.168,   5.207,   5.242,   5.281,   5.320,   5.355,   5.395,   5.434,   5.469,   5.508,   5.547,   5.582,   5.621,
   8.125,   8.188,   8.250,   8.312,   8.383,   8.445,   8.508,   8.570,   8.633,   8.695,   8.758,   8.820,   8.883,   8.945,   9.016,   9.078,
  11.203,  11.289,  11.375,  11.461,  11.555,  11.641,  11.734,  11.820,  11.906,  12.000,  12.086,  12.172,  12.266,  12.352,  12.445,  12.531,
  14.273,  14.383,  14.500,  14.617,  14.727,  14.844,  14.961,  15.070,  15.188,  15.305,  15.414,  15.531,  15.641,  15.758,  15.875,  15.992,
  17.344,  17.484,  17.625,  17.766,  17.906,  18.047,  18.188,  18.328,  18.469,  18.609,  18.750,  18.891,  19.016,  19.172,  19.312,  19.453,
  20.422,  20.578,  20.750,  20.906,  21.078,  21.234,  21.406,  21.578,  21.734,  21.906,  22.062,  22.234,  22.406,  22.562,  22.734,  22.906,
  23.484,  23.672,  23.875,  24.062,  24.250,  24.438,  24.641,  24.828,  25.016,  25.203,  25.391,  25.594,  25.781,  25.969,  26.172,  26.359,
  26.562,  26.781,  27.000,  27.203,  27.422,  27.641,  27.859,  28.078,  28.297,  28.516,  28.719,  28.938,  29.156,  29.375,  29.594,  29.812,
  29.625,  29.875,  30.109,  30.359,  30.594,  30.844,  31.094,  31.328,  31.578,  31.812,  32.062,  32.281,  32.531,  32.781,  33.031,  33.281,
  32.719,  32.969,  33.250,  33.500,  33.781,  34.031,  34.312,  34.594,  34.844,  35.125,  35.375,  35.656,  35.906,  36.188,  36.469,  36.719,
  35.781,  36.062,  36.375,  36.656,  36.938,  37.250,  37.531,  37.844,  38.125,  38.406,  38.719,  39.000,  39.281,  39.594,  39.875,  40.188,
  38.844,  39.156,  39.500,  39.812,  40.125,  40.438,  40.781,  41.094,  41.406,  41.719,  42.031,  42.344,  42.688,  43.000,  43.312,  43.625,
  41.938,  42.281,  42.625,  42.969,  43.312,  43.656,  44.000,  44.344,  44.688,  45.031,  45.375,  45.719,  46.062,  46.406,  46.750,  47.094,
  45.000,  45.375,  45.719,  46.094,  46.469,  46.844,  47.219,  47.594,  47.969,  48.312,  48.688,  49.062,  49.438,  49.812,  50.188,  50.562,
  48.062,  48.469,  48.844,  49.250,  49.656,  50.031,  50.438,  50.844,  51.219,  51.625,  52.031,  52.406,  52.812,  53.219,  53.594,  54.000,

2.cuda kernel测试

tee mma_ops.cu<<-'EOF'
#include <iostream>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cuda.h>
#include <mma.h>

using namespace nvcuda;

#define WARP_SIZE 32

#define CHECK_CUDA(status)                                              \
  {                                                                     \
    cudaError_t error = status;                                         \
    if (error != cudaSuccess) {                                         \
      std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \
                << " at line: " << __LINE__ << std::endl;               \
      exit(EXIT_FAILURE);                                               \
    }                                                                   \
  }

//mma指令
#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1)                                                    \
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" \
                 : "=r"(RD0), "=r"(RD1)                                                                                \
                 : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1))

//加载A矩阵(行存储)
#define LDMATRIX_X4(R0, R1, R2, R3, addr)                                             \
    asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" \
                 : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3)                             \
                 : "l"(addr))

#define LDMATRIX_X4_TRANS(R0, R1, R2, R3, addr)                                             \
    asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" \
                 : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3)                             \
                 : "l"(addr))

//加载B矩阵(行存储),需要转置
#define LDMATRIX_X2(R0, R1, addr) \
    asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "l"(addr))

//异步加载数据
#define CP_ASYNC_CG(dst, src, Bytes) \
    asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(Bytes))
#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)
#define CP_ASYNC_WAIT_GROUP(N) asm volatile("cp.async.wait_group %0;\n" ::"n"(N))
#define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::)

__global__ void ptx_m16n8k16_kernel(half* input_A, half* input_B, half* input_C, int M, int N, int K) {
    
    const size_t laneid = threadIdx.x % WARP_SIZE;
    
    __shared__ half A[16*16];
    __shared__ half B[16*16];      
    //为了保证对比的公平,都先加载到share memory里
    uint32_t a_smem_lane_addr = __cvta_generic_to_shared(&A[laneid*8]); 
    CP_ASYNC_CG(a_smem_lane_addr,&input_A[laneid*8],16);
     
    uint32_t b_smem_lane_addr = __cvta_generic_to_shared(&B[laneid*8]); 
    CP_ASYNC_CG(b_smem_lane_addr,&input_B[laneid*8],16); 

    CP_ASYNC_COMMIT_GROUP();
    CP_ASYNC_WAIT_GROUP(0);
    __syncthreads();
    
    clock_t begin=clock64();
    uint32_t RA[4];
    uint32_t RB[2][2];
    uint32_t RC[2][2];
    
    RC[0][0]=0;
    RC[0][1]=0;
    
    RC[1][0]=0;
    RC[1][1]=0;
    
    int aTile_index = laneid % 16 * 16 + laneid / 16 * 8;
    LDMATRIX_X4(RA[0], RA[1], RA[2], RA[3], __cvta_generic_to_shared(&A[aTile_index]));
    LDMATRIX_X2(RB[0][0], RB[0][1],__cvta_generic_to_shared(&B[laneid % 16 * 16]));
    LDMATRIX_X2(RB[1][0],RB[1][1],__cvta_generic_to_shared(&B[laneid % 16 * 16+8]));
    
    //执行mma执行
    HMMA16816(RC[0][0], RC[0][1],
              RA[0], RA[1], RA[2], RA[3],
              RB[0][0], RB[0][1],
              RC[0][0], RC[0][1]);   
    HMMA16816(RC[1][0], RC[1][1],
              RA[0], RA[1], RA[2], RA[3],
              RB[1][0], RB[1][1],
              RC[1][0], RC[1][1]);
              
    //C矩阵 M*N=16*8
    /*
    groupID           = %laneid >> 2
    threadID_in_group = %laneid % 4

    row =    groupID                                 for ci where i <  2
             groupID + 8                             for ci where i >= 2

    col =  (threadID_in_group * 2) + (i & 0x1)       for ci where i = {0,..,3}
    */

    int groupID           = laneid /4;
    int threadID_in_group = laneid % 4;
    
    int row_c0 = groupID;
    int col_c0 = (threadID_in_group * 2) + (0 & 0x1);
    
    int row_c2 = groupID + 8;
    int col_c2 = (threadID_in_group * 2) + (2 & 0x1);              
              
    //写回到DRAM
    *(uint32_t*)&input_C[row_c0*N+col_c0]=RC[0][0];
    *(uint32_t*)&input_C[row_c2*N+col_c2]=RC[0][1];    
    *(uint32_t*)&input_C[row_c0*N+8+col_c0]=RC[1][0];
    *(uint32_t*)&input_C[row_c2*N+8+col_c2]=RC[1][1];    

    clock_t end=clock64();
    
    if(laneid==0)
    {
        printf("ptx_mma_shared kernel e2e(cycles):%ld\n",end-begin);
    }    
}

__global__ void wmma_api_kernel(half *dev_a, half *dev_b,half *dev_c) {
    int tid  = threadIdx.x + blockIdx.x * blockDim.x;

    __shared__ half A[16*16];
    __shared__ half B[16*16];
    
    //为了保证对比的公平,都先加载到share memory里
    uint32_t a_smem_lane_addr = __cvta_generic_to_shared(&A[threadIdx.x*8]); 
    CP_ASYNC_CG(a_smem_lane_addr,&dev_a[threadIdx.x*8],16);
     
    uint32_t b_smem_lane_addr = __cvta_generic_to_shared(&B[threadIdx.x*8]); 
    CP_ASYNC_CG(b_smem_lane_addr,&dev_b[threadIdx.x*8],16); 

    CP_ASYNC_COMMIT_GROUP();
    CP_ASYNC_WAIT_GROUP(0);
    __syncthreads();
    
    clock_t begin=clock64();
    nvcuda::wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    nvcuda::wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
    nvcuda::wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;    
    wmma::load_matrix_sync(a_frag, A, 16);
    wmma::load_matrix_sync(b_frag, B, 16);
    wmma::fill_fragment(c_frag, 0.0f);
    nvcuda::wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); 
    wmma::store_matrix_sync(dev_c, c_frag, 16, wmma::mem_row_major);
    clock_t end=clock64();
    if(tid==0)
    {
        printf("wmma_kernel e2e(cycles):%ld\n",end-begin);
    } 
}

int M=16;
int N=16;    
int K=16;

void dump(half *host_c)
{
    for(int r=0;r<M;r++)
    {
       for(int c=0;c<N;c++)
       {
        printf("%8.3f ",__half2float(host_c[r*N+c]));
       }
       printf("\n");
    }
}

int main() {
    half *host_a = new half[M*K];
    half *host_b = new half[K*N];
    half *host_c = new half[M*N];
    
    half *dev_a;
    half *dev_b;
    half *dev_c;
    
    CHECK_CUDA(cudaMalloc(&dev_a, sizeof(half)*M*K));
    CHECK_CUDA(cudaMalloc(&dev_b, sizeof(half)*K*N));
    CHECK_CUDA(cudaMalloc(&dev_c, sizeof(half)*M*N));
    
    for(int i = 0; i < M*K; ++i) host_a[i] = __float2half(i*0.01);
    for(int i = 0; i < K*N; ++i) host_b[i] = __float2half(i*0.01);
    
    for(int j=0;j<1;j++)
    {
        CHECK_CUDA(cudaMemcpy(dev_a, host_a, sizeof(half)*M*K,cudaMemcpyHostToDevice));
        CHECK_CUDA(cudaMemcpy(dev_b, host_b, sizeof(half)*K*N,cudaMemcpyHostToDevice));
        for(int i = 0; i < M*N; ++i) host_c[i] = 0;
        CHECK_CUDA(cudaMemcpy(dev_c, host_c, sizeof(half)*K*N,cudaMemcpyHostToDevice));
      
        ptx_m16n8k16_kernel<<<1, 32>>>(dev_a, dev_b,dev_c,M,N,K);cudaDeviceSynchronize();
        cudaMemcpy(host_c, dev_c, sizeof(half)*M*N, cudaMemcpyDeviceToHost);
        dump(host_c);
        
        printf("-------------------------------------------------------------\n");
        for(int i = 0; i < M*N; ++i) host_c[i] = 0;
        CHECK_CUDA(cudaMemcpy(dev_c, host_c, sizeof(half)*K*N,cudaMemcpyHostToDevice));  
        wmma_api_kernel<<<1, 32>>>(dev_a, dev_b,dev_c);cudaDeviceSynchronize();
        cudaMemcpy(host_c, dev_c, sizeof(half)*M*N, cudaMemcpyDeviceToHost);
        dump(host_c);
    }
    
    cudaFree(dev_a);
    cudaFree(dev_b);
    cudaFree(dev_c);
    free(host_a);
    free(host_b);
    free(host_c);
    return 0;
}
EOF

/usr/local/cuda/bin/nvcc -std=c++17 -O2 -arch=sm_86 -lineinfo mma_ops.cu -o mma_ops
./mma_ops

/usr/local/NVIDIA-Nsight-Compute/ncu --set full --section SpeedOfLight_HierarchicalTensorRooflineChart --target-processes all --clock-control=none \
                --print-details all --export ncu_report_mma_diff_ops -f ./mma_ops

# 查看tensor core利用率
/usr/local/NVIDIA-Nsight-Compute/ncu --metrics \
sm__ops_path_tensor_src_fp16_dst_fp16_sparsity_off.sum.pct_of_peak_sustained_elapsed,\
sm__ops_path_tensor_src_fp16_dst_fp16_sparsity_off.sum,\
sm__ops_path_tensor_src_fp16_dst_fp16_sparsity_off.sum.peak_sustained,\
sm__ops_path_tensor_src_fp16_dst_fp16_sparsity_off.sum.per_second,\
sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed,\
sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_active,\
sm__cycles_elapsed ./mma_ops

输出

ptx_mma_shared kernel e2e(cycles):144
   1.984    1.996    2.008    2.020    2.031    2.045    2.057    2.068    2.080    2.092    2.104    2.115    2.127    2.141    2.152    2.164
   5.059    5.094    5.133    5.168    5.207    5.242    5.281    5.320    5.355    5.395    5.434    5.469    5.508    5.547    5.582    5.621
   8.125    8.188    8.250    8.312    8.383    8.445    8.508    8.570    8.633    8.695    8.758    8.820    8.883    8.953    9.016    9.078
  11.195   11.289   11.375   11.461   11.555   11.641   11.734   11.820   11.906   12.000   12.086   12.172   12.266   12.352   12.445   12.531
  14.273   14.391   14.500   14.617   14.727   14.844   14.961   15.078   15.188   15.305   15.414   15.531   15.641   15.758   15.875   15.992
  17.344   17.484   17.625   17.766   17.906   18.047   18.188   18.328   18.469   18.609   18.750   18.875   19.016   19.172   19.312   19.438
  20.422   20.578   20.750   20.906   21.078   21.234   21.406   21.578   21.734   21.906   22.078   22.234   22.406   22.562   22.734   22.906
  23.484   23.688   23.875   24.062   24.250   24.438   24.641   24.828   25.016   25.203   25.391   25.594   25.781   25.969   26.172   26.359
  26.562   26.781   27.000   27.203   27.422   27.641   27.859   28.078   28.297   28.516   28.719   28.938   29.156   29.375   29.594   29.812
  29.641   29.875   30.109   30.359   30.594   30.844   31.094   31.328   31.578   31.812   32.062   32.281   32.531   32.781   33.031   33.281
  32.719   32.969   33.250   33.500   33.781   34.031   34.312   34.594   34.844   35.125   35.375   35.656   35.906   36.188   36.469   36.719
  35.781   36.062   36.375   36.656   36.938   37.250   37.531   37.844   38.125   38.406   38.719   39.000   39.281   39.594   39.875   40.188
  38.844   39.156   39.500   39.812   40.125   40.438   40.781   41.094   41.406   41.719   42.031   42.344   42.656   43.000   43.312   43.625
  41.938   42.250   42.625   42.969   43.281   43.656   44.000   44.344   44.688   45.031   45.375   45.719   46.062   46.406   46.750   47.094
  45.000   45.375   45.719   46.094   46.469   46.844   47.219   47.594   47.969   48.312   48.688   49.062   49.438   49.812   50.188   50.562
  48.062   48.469   48.844   49.250   49.656   50.031   50.438   50.844   51.250   51.625   52.031   52.406   52.812   53.219   53.625   54.000
-------------------------------------------------------------
wmma_kernel e2e(cycles):163
   1.984    1.996    2.008    2.020    2.031    2.045    2.057    2.068    2.080    2.092    2.104    2.115    2.127    2.141    2.152    2.164
   5.059    5.094    5.133    5.168    5.207    5.242    5.281    5.320    5.355    5.395    5.434    5.469    5.508    5.547    5.582    5.621
   8.125    8.188    8.250    8.312    8.383    8.445    8.508    8.570    8.633    8.695    8.758    8.820    8.883    8.953    9.016    9.078
  11.195   11.289   11.375   11.461   11.555   11.641   11.734   11.820   11.906   12.000   12.086   12.172   12.266   12.352   12.445   12.531
  14.273   14.391   14.500   14.617   14.727   14.844   14.961   15.078   15.188   15.305   15.414   15.531   15.641   15.758   15.875   15.992
  17.344   17.484   17.625   17.766   17.906   18.047   18.188   18.328   18.469   18.609   18.750   18.875   19.016   19.172   19.312   19.438
  20.422   20.578   20.750   20.906   21.078   21.234   21.406   21.578   21.734   21.906   22.078   22.234   22.406   22.562   22.734   22.906
  23.484   23.688   23.875   24.062   24.250   24.438   24.641   24.828   25.016   25.203   25.391   25.594   25.781   25.969   26.172   26.359
  26.562   26.781   27.000   27.203   27.422   27.641   27.859   28.078   28.297   28.516   28.719   28.938   29.156   29.375   29.594   29.812
  29.641   29.875   30.109   30.359   30.594   30.844   31.094   31.328   31.578   31.812   32.062   32.281   32.531   32.781   33.031   33.281
  32.719   32.969   33.250   33.500   33.781   34.031   34.312   34.594   34.844   35.125   35.375   35.656   35.906   36.188   36.469   36.719
  35.781   36.062   36.375   36.656   36.938   37.250   37.531   37.844   38.125   38.406   38.719   39.000   39.281   39.594   39.875   40.188
  38.844   39.156   39.500   39.812   40.125   40.438   40.781   41.094   41.406   41.719   42.031   42.344   42.656   43.000   43.312   43.625
  41.938   42.250   42.625   42.969   43.281   43.656   44.000   44.344   44.688   45.031   45.375   45.719   46.062   46.406   46.750   47.094
  45.000   45.375   45.719   46.094   46.469   46.844   47.219   47.594   47.969   48.312   48.688   49.062   49.438   49.812   50.188   50.562
  48.062   48.469   48.844   49.250   49.656   50.031   50.438   50.844   51.250   51.625   52.031   52.406   52.812   53.219   53.625   54.000
  • 13
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Hi20240217

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值