目录
1 WMMA (Warp-level Matrix Multiply Accumulate) API
2 示例(2.1 CUDA Core;2.2 Tensor Core;2.3 区别)
3 底层代码(3.1 PTX;3.2 SASS)
4 其他(4.1 HGEMM优化)
1 WMMA (Warp-level Matrix Multiply Accumulate) API
对于计算能力在7.0及以上的CUDA设备,可以使用CUDA C++ WMMA API(头文件mma.h,命名空间nvcuda::wmma)调用Tensor Core,支持形如D = AB + C的混合精度矩阵乘加运算。其主要接口如下:
// fragment: per-warp storage for one tile of matrix_a, matrix_b or the accumulator.
template<typename Use, int m, int n, int k, typename T, typename Layout=void> class fragment;
// Load a tile from memory into a fragment; ldm is the stride (leading dimension) in elements.
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm);
// Overload taking the memory layout explicitly (presumably for fragments whose Layout is left as void, e.g. the accumulator).
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm, layout_t layout);
// Store a fragment (the computed result) back to memory with stride ldm and the given layout.
void store_matrix_sync(T* mptr, const fragment<...> &a, unsigned ldm, layout_t layout);
// Set every element of the fragment to the constant v.
void fill_fragment(fragment<...> &a, const T& v);
// Warp-wide multiply-accumulate d = a * b + c; passing the same fragment as d and c gives c = a * b + c.
void mma_sync(fragment<...> &d, const fragment<...> &a, const fragment<...> &b, const fragment<...> &c, bool satf=false);
- fragment:Tensor Core数据存储类,支持matrix_a、matrix_b和accumulator三种用途
- load_matrix_sync:Tensor Core数据加载API,支持将矩阵数据从global memory或shared memory加载到fragment,其中ldm为矩阵在内存中的leading dimension(以元素为单位)
- store_matrix_sync:Tensor Core结果存储API,支持将计算结果从fragment存储到global memory或shared memory
- fill_fragment:fragment填充API,支持用常数值填充fragment
- mma_sync:Tensor Core矩阵乘计算API,支持D = AB + C或者C = AB + C(即d与c传入同一个fragment)
2 示例
以m16n16k16为例,实现HGEMM:C = AB,其中矩阵A(M * K,row major)、B(K * N,col major)和C(M * N,row major)的精度均为FP16。首先我们看如何使用CUDA Core写HGEMM naive算法。
2.1 CUDA Core
按照每个线程计算矩阵C中的一个元素来构建naive kernel:首先确定当前线程处理的矩阵C元素坐标(由于grid按向上取整划分,行或列任一越界的线程都应直接返回),再遍历K并直接从global memory中加载所需A、B矩阵元素到寄存器参与计算,最后将计算结果从寄存器直接写回矩阵C。所有block计算完成之后即可得到矩阵C。这个kernel没有做任何优化,技术含量不高,仅作为后文对比的基准。源码在cuda_hgemm。
// Naive HGEMM on CUDA Cores: C = A * B, one thread per element of C.
//   A: M x K, row-major    -> element (r, k) is A[r * K + k]
//   B: K x N, column-major -> element (k, c) is B[k + c * K]
//   C: M x N, row-major    -> element (r, c) is C[r * N + c]
// Expected launch: 2D grid/block, x covering the columns of C and y its rows (see simtNaive);
// the grid may be rounded up, so threads outside C must exit.
// Products are accumulated in float and rounded to half once at the end.
__global__ void simtNaiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t M,
                                size_t N, size_t K) {
    // Widen before multiplying so very large grids cannot overflow 32-bit arithmetic.
    const size_t row = threadIdx.y + static_cast<size_t>(blockDim.y) * blockIdx.y;
    const size_t col = threadIdx.x + static_cast<size_t>(blockDim.x) * blockIdx.x;
    // A thread is outside C if EITHER coordinate is out of range. The previous `&&` only
    // skipped threads outside both edges, letting e.g. row >= M, col < N read and write
    // out of bounds whenever M or N is not a multiple of the block size.
    if (row >= M || col >= N) {
        return;
    }
    const half *a_row = A + row * K;  // row `row` of A, contiguous
    const half *b_col = B + col * K;  // column `col` of B, contiguous because B is col-major
    float tmp = 0.0f;
#pragma unroll
    for (size_t i = 0; i < K; ++i) {
        tmp += __half2float(a_row[i]) * __half2float(b_col[i]);
    }
    C[row * N + col] = __float2half(tmp);
}
// Host launcher for simtNaiveKernel: 16 x 16 threads per block, one thread per element of C,
// with the grid rounded up (div_ceil) so the whole M x N output is covered.
void simtNaive(half *A, half *B, half *C, size_t M, size_t N, size_t K) {
    const dim3 threads_per_block(16, 16);
    const dim3 blocks_per_grid(div_ceil(N, threads_per_block.x), div_ceil(M, threads_per_block.y));
    simtNaiveKernel<<<blocks_per_grid, threads_per_block>>>(A, B, C, M, N, K);
}
2.2 Tensor Core
我们再来看如何用WMMA API来构建naive kernel,参考cuda-sample。与CUDA Core naive不同的是,WMMA需要按照每个warp处理矩阵C中一个WMMA_M * WMMA_N大小的tile的思路来构建(这里每个block只包含一个warp),因为Tensor Core的计算层级是warp级别,计算的矩阵元素也是二维的。接下来,与CUDA Core naive的处理思路一致,首先确定当前warp处理矩阵C的tile坐标,声明计算tile所需的fragment,再以WMMA_K为步长遍历K并直接从global memory中加载所需A、B矩阵tile到fragment参与计算,最后将计算结果从fragment直接写回矩阵C。所有block计算完成之后即可得到矩阵C。
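值得注意的是,load_matrix_sync和store_matrix_sync都是通过参数ldm(leading dimension,以元素为单位)按stride访问矩阵元素的:这里A和B的ldm均为K,C的ldm为N。WMMA要求half类型矩阵的ldm为8的倍数,且mptr按256-bit对齐;此外,这个naive kernel每次都完整地加载/存储16 × 16的tile,没有处理边界上的不完整tile,因此要求M、N、K均为16的整数倍,否则会越界访问。源码在cuda_hgemm。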
// Shape of one WMMA operation: m16n16k16 (M x N output tile, K-slice depth per mma_sync).
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16
// Threads per warp; the WMMA launcher below uses exactly one warp per block.
#define WARP_SIZE 32
// Makes the nvcuda::wmma types and functions available as wmma::...
using namespace nvcuda;
// Naive HGEMM on Tensor Cores through the WMMA API: C = A * B (all half).
//   A: M x K, row-major    (ldm = K)
//   B: K x N, column-major (ldm = K)
//   C: M x N, row-major    (ldm = N)
// Each block is a single warp that computes one WMMA_M x WMMA_N tile of C, walking K in steps of
// WMMA_K and loading A/B tiles straight from global memory (no shared-memory staging).
// Expected launch: grid(div_ceil(N, WMMA_N), div_ceil(M, WMMA_M)), block(WARP_SIZE), as in wmmaNaive.
// Preconditions (NOT checked here): M, N and K are multiples of 16 -- tiles are always loaded and
// stored whole, so a partial edge tile would be accessed out of bounds. WMMA additionally requires
// ldm to be a multiple of 8 for half elements and each tile pointer to be 256-bit aligned.
// Requires compute capability 7.0 or newer.
__global__ void wmmaNaiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t M,
                                size_t N, size_t K) {
    const size_t K_tiles = div_ceil(K, WMMA_K);
    // Top-left corner of this warp's output tile (widened before multiplying).
    const size_t warp_row = static_cast<size_t>(blockIdx.y) * WMMA_M;
    const size_t warp_col = static_cast<size_t>(blockIdx.x) * WMMA_N;
    // A tile that starts past EITHER edge of C has nothing to compute; `&&` would only skip tiles
    // outside both edges. The test depends only on blockIdx, so the whole warp returns together and
    // the warp-synchronous WMMA calls below are never reached by part of a warp.
    if (warp_row >= M || warp_col >= N) {
        return;
    }

    // Half-precision accumulator for the 16x16 output tile, cleared before the K loop.
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> C_frag;
    wmma::fill_fragment(C_frag, 0.0f);

#pragma unroll
    for (size_t i = 0; i < K_tiles; ++i) {
        wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> A_frag;
        wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> B_frag;
        // A tile: rows [warp_row, warp_row + 16), columns [i * 16, i * 16 + 16) of row-major A.
        wmma::load_matrix_sync(A_frag, A + warp_row * K + i * WMMA_K, K);
        // B tile: columns [warp_col, warp_col + 16), rows [i * 16, i * 16 + 16) of column-major B.
        wmma::load_matrix_sync(B_frag, B + i * WMMA_K + warp_col * K, K);
        wmma::mma_sync(C_frag, A_frag, B_frag, C_frag);  // C_frag += A_frag * B_frag
    }

    wmma::store_matrix_sync(C + warp_row * N + warp_col, C_frag, N, wmma::mem_row_major);
}
// Host launcher for wmmaNaiveKernel: one block per WMMA_M x WMMA_N tile of C (grid rounded up in
// both dimensions), each block consisting of a single warp.
void wmmaNaive(half *A, half *B, half *C, size_t M, size_t N, size_t K) {
    const dim3 tiles(div_ceil(N, WMMA_N), div_ceil(M, WMMA_M));
    const dim3 one_warp(WARP_SIZE);
    wmmaNaiveKernel<<<tiles, one_warp>>>(A, B, C, M, N, K);
}
2.3 区别
从上述两个naive kernel的代码来看,调用CUDA Core和Tensor Core的区别如下:
- 计算层级:CUDA Core是线程级别,Tensor Core是warp级别
- 计算维度:CUDA Core是一维逐点计算,Tensor Core是二维逐tile计算
- 计算依赖:WMMA调用Tensor Core需要借助数据存储类fragment,CUDA Core不需要借助其他
3 底层代码
我们再对上述WMMA naive kernel做进一步探索,看一下它在RTX A6000(sm_86,CUDA 11.3)上对应的PTX和SASS。
3.1 PTX
dump出对应的PTX代码如下,好像不那么简单了。
// PTX for wmmaNaiveKernel. Parameters 0..5 are A, B, C, M, N, K in kernel-signature order;
// all six are 64-bit (pointers and size_t).
.visible .entry _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm(
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_0,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_1,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_2,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_3,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_4,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_5
)
{
.reg .pred %p<8>;
.reg .b16 %rs<2>;
.reg .f32 %f<2>;
.reg .b32 %r<44>;
.reg .b64 %rd<36>;
// %rd14 = A, %rd15 = B, %rd16 = C, %rd19 = M, %rd17 = N, %rd18 = K.
ld.param.u64 %rd14, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_0];
ld.param.u64 %rd15, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_1];
ld.param.u64 %rd16, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_2];
ld.param.u64 %rd19, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_3];
ld.param.u64 %rd17, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_4];
ld.param.u64 %rd18, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_5];
// %rd1 = K / 16 (first part of div_ceil(K, WMMA_K)).
shr.u64 %rd1, %rd18, 4;
// %rd2 = warp_row = ctaid.y * 16, %rd3 = warp_col = ctaid.x * 16 (computed in 32 bits, then widened).
mov.u32 %r15, %ctaid.y;
shl.b32 %r16, %r15, 4;
cvt.u64.u32 %rd2, %r16;
mov.u32 %r17, %ctaid.x;
shl.b32 %r18, %r17, 4;
cvt.u64.u32 %rd3, %r18;
// Early exit: and.pred means the kernel returns only when BOTH warp_row >= M and warp_col >= N,
// i.e. this dump was produced from a bounds check written with `&&`.
setp.ge.u64 %p1, %rd2, %rd19;
setp.ge.u64 %p2, %rd3, %rd17;
and.pred %p3, %p1, %p2;
@%p3 bra $L__BB0_5;
// %p4 = (K % 16 != 0): whether div_ceil has to round up.
and.b64 %rd4, %rd18, 15;
setp.ne.s64 %p4, %rd4, 0;
// Zero the accumulator: %r40..%r43 each hold two packed f16 zeros.
mov.f32 %f1, 0f00000000;
{ cvt.rn.f16.f32 %rs1, %f1;}
mov.b32 %r40, {%rs1, %rs1};
// Skip the K loop when K_tiles = K/16 + (K%16 != 0) is zero, tested as K/16 == -(K%16 != 0).
selp.b64 %rd20, -1, 0, %p4;
setp.eq.s64 %p5, %rd1, %rd20;
mov.u32 %r41, %r40;
mov.u32 %r42, %r40;
mov.u32 %r43, %r40;
@%p5 bra $L__BB0_4;
// Loop setup: %rd35 = K_tiles (trip count), %rd33 = &A[warp_row * K], %rd34 = &B[warp_col * K]
// (byte offset = element offset << 1), %r2 = ldm = K truncated to 32 bits.
mul.lo.s64 %rd21, %rd2, %rd18;
cvt.u32.u64 %r2, %rd18;
selp.u64 %rd22, 1, 0, %p4;
add.s64 %rd35, %rd1, %rd22;
mul.lo.s64 %rd23, %rd3, %rd18;
cvta.to.global.u64 %rd24, %rd15;
shl.b64 %rd25, %rd23, 1;
add.s64 %rd34, %rd24, %rd25;
cvta.to.global.u64 %rd26, %rd14;
shl.b64 %rd27, %rd21, 1;
add.s64 %rd33, %rd26, %rd27;
mov.u32 %r41, %r40;
mov.u32 %r42, %r40;
mov.u32 %r43, %r40;
// K loop: one wmma.load.a, one wmma.load.b and one wmma.mma per 16-wide slice of K.
// Per lane, fragments a and b occupy 8 x 32-bit registers each; the f16 accumulator occupies 4.
$L__BB0_3:
wmma.load.a.sync.aligned.row.m16n16k16.global.f16 {%r19, %r20, %r21, %r22, %r23, %r24, %r25, %r26}, [%rd33], %r2;
wmma.load.b.sync.aligned.col.m16n16k16.global.f16 {%r27, %r28, %r29, %r30, %r31, %r32, %r33, %r34}, [%rd34], %r2;
wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16 {%r43, %r42, %r41, %r40}, {%r19, %r20, %r21, %r22, %r23, %r24, %r25, %r26}, {%r27, %r28, %r29, %r30, %r31, %r32, %r33, %r34}, {%r43, %r42, %r41, %r40};
// Advance both tile pointers by 16 halves (32 bytes) along K and count K_tiles down to zero.
add.s64 %rd34, %rd34, 32;
add.s64 %rd33, %rd33, 32;
add.s64 %rd35, %rd35, -1;
setp.ne.s64 %p7, %rd35, 0;
@%p7 bra $L__BB0_3;
// Store: %rd32 = &C[warp_row * N + warp_col], ldm = N truncated to 32 bits.
$L__BB0_4:
mul.lo.s64 %rd28, %rd2, %rd17;
add.s64 %rd29, %rd28, %rd3;
cvta.to.global.u64 %rd30, %rd16;
shl.b64 %rd31, %rd29, 1;
add.s64 %rd32, %rd30, %rd31;
cvt.u32.u64 %r35, %rd17;
wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%rd32], {%r43, %r42, %r41, %r40}, %r35;
$L__BB0_5:
ret;
}
不过我们主要关注WMMA相关的PTX指令,如下所示。可以看到,这里正是通过Nvidia提供的WMMA PTX指令来调用Tensor Core,所以无论是使用WMMA API编程,还是直接使用WMMA PTX指令编程,底层差别都不会太大。
wmma.load.a.sync.aligned.row.m16n16k16.global.f16
wmma.load.b.sync.aligned.col.m16n16k16.global.f16
wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16
wmma.store.d.sync.aligned.row.m16n16k16.global.f16
3.2 SASS
进一步dump出对应的SASS代码,似乎也不简单。
// SASS for wmmaNaiveKernel on sm_86. Kernel parameters live in constant bank 0 starting at
// c[0x0][0x160]: A = 0x160, B = 0x168, C = 0x170, M = 0x178, N = 0x180, K = 0x188 (8 bytes each).
Function : _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm
.headerflags @"EF_CUDA_SM86 EF_CUDA_PTX_SM(EF_CUDA_SM86)"
// R1 is initialised from c[0x0][0x28] (standard prologue; presumably the stack pointer).
/*0000*/ IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28] ; /* 0x00000a00ff017624 */
/* 0x000fc400078e00ff */
// R32 = warp_col = ctaid.x * 16, R18 = warp_row = ctaid.y * 16, R0 = low word of K.
/*0010*/ S2R R32, SR_CTAID.X ; /* 0x0000000000207919 */
/* 0x000e220000002500 */
/*0020*/ IMAD.MOV.U32 R0, RZ, RZ, c[0x0][0x188] ; /* 0x00006200ff007624 */
/* 0x000fc600078e00ff */
/*0030*/ S2R R18, SR_CTAID.Y ; /* 0x0000000000127919 */
/* 0x000e620000002600 */
/*0040*/ IMAD.SHL.U32 R32, R32, 0x10, RZ ; /* 0x0000001020207824 */
/* 0x001fe400078e00ff */
/*0050*/ IMAD.SHL.U32 R18, R18, 0x10, RZ ; /* 0x0000001012127824 */
/* 0x002fc600078e00ff */
// P0 = warp_col >= N, P1 = warp_row >= M (64-bit unsigned compares; the tile index has a zero high word).
/*0060*/ ISETP.GE.U32.AND P0, PT, R32, c[0x0][0x180], PT ; /* 0x0000600020007a0c */
/* 0x000fe40003f06070 */
/*0070*/ ISETP.GE.U32.AND P1, PT, R18, c[0x0][0x178], PT ; /* 0x00005e0012007a0c */
/* 0x000fe40003f26070 */
/*0080*/ ISETP.GE.U32.AND.EX P0, PT, RZ, c[0x0][0x184], PT, P0 ; /* 0x00006100ff007a0c */
/* 0x000fe40003f06100 */
/*0090*/ ISETP.GE.U32.AND.EX P1, PT, RZ, c[0x0][0x17c], PT, P1 ; /* 0x00005f00ff007a0c */
/* 0x000fda0003f26110 */
// Exit only when P0 and P1 are both set -- the `&&` bounds check of the kernel at dump time.
/*00a0*/ @P0 EXIT P1 ; /* 0x000000000000094d */
/* 0x000fea0000800000 */
// P0 = (K & 15) != 0; R5:R3 = K >> 4; R2 = P0 ? -1 : 0; R16/R17 and R20/R21 (the accumulator
// registers stored at the end) are zeroed. UR4 is loaded from c[0x0][0x118]; presumably the
// global-memory descriptor used by the LDG.E/STG.E below (not confirmed here).
/*00b0*/ LOP3.LUT P0, RZ, R0.reuse, 0xf, RZ, 0xc0, !PT ; /* 0x0000000f00ff7812 */
/* 0x040fe2000780c0ff */
/*00c0*/ IMAD.MOV.U32 R3, RZ, RZ, 0x4 ; /* 0x00000004ff037424 */
/* 0x000fe200078e00ff */
/*00d0*/ ULDC.64 UR4, c[0x0][0x118] ; /* 0x0000460000047ab9 */
/* 0x000fe20000000a00 */
/*00e0*/ IMAD.MOV.U32 R5, RZ, RZ, c[0x0][0x18c] ; /* 0x00006300ff057624 */
/* 0x000fe200078e00ff */
/*00f0*/ LOP3.LUT P0, RZ, RZ, c[0x0][0x18c], RZ, 0xc0, P0 ; /* 0x00006300ffff7a12 */
/* 0x000fe2000000c0ff */
/*0100*/ CS2R R20, SRZ ; /* 0x0000000000147805 */
/* 0x000fe2000001ff00 */
/*0110*/ SHF.R.U64 R3, R0, R3, c[0x0][0x18c] ; /* 0x0000630000037619 */
/* 0x000fe20000001203 */
/*0120*/ CS2R R16, SRZ ; /* 0x0000000000107805 */
/* 0x000fe2000001ff00 */
/*0130*/ SEL R2, RZ, 0xffffffff, !P0 ; /* 0xffffffffff027807 */
/* 0x000fe40004000000 */
/*0140*/ SHF.R.U32.HI R5, RZ, 0x4, R5 ; /* 0x00000004ff057819 */
/* 0x000fc40000011605 */
// If K >> 4 == R2 (i.e. K_tiles == 0), skip all K loops and go straight to the store at 0xbc0.
/*0150*/ ISETP.NE.U32.AND P1, PT, R3, R2, PT ; /* 0x000000020300720c */
/* 0x000fc80003f25070 */
/*0160*/ ISETP.NE.AND.EX P1, PT, R5, R2, PT, P1 ; /* 0x000000020500720c */
/* 0x000fda0003f25310 */
/*0170*/ @!P1 BRA 0xbc0 ; /* 0x00000a4000009947 */
/* 0x000fea0003800000 */
// Loop setup. R3:R2 = K_tiles = (K >> 4) + ((K & 15) != 0).
// A tile pointer R6:R4 (hi:lo) = A + warp_row * K * 2 bytes; B tile pointer R7:R5 = B + warp_col * K * 2 bytes.
// R16/R17 and R20 are re-zeroed here.
/*0180*/ SEL R2, RZ, 0x1, !P0 ; /* 0x00000001ff027807 */
/* 0x000fe20004000000 */
/*0190*/ IMAD.WIDE.U32 R6, R18, c[0x0][0x188], RZ ; /* 0x0000620012067a25 */
/* 0x000fe200078e00ff */
/*01a0*/ CS2R R16, SRZ ; /* 0x0000000000107805 */
/* 0x000fe4000001ff00 */
/*01b0*/ IADD3 R2, P0, R2, R3, RZ ; /* 0x0000000302027210 */
/* 0x000fe20007f1e0ff */
/*01c0*/ IMAD.WIDE.U32 R8, R32, c[0x0][0x188], RZ ; /* 0x0000620020087a25 */
/* 0x000fe200078e00ff */
/*01d0*/ LEA R4, P2, R6, c[0x0][0x160], 0x1 ; /* 0x0000580006047a11 */
/* 0x000fc600078408ff */
/*01e0*/ IMAD.X R3, RZ, RZ, R5, P0 ; /* 0x000000ffff037224 */
/* 0x000fe200000e0605 */
/*01f0*/ ISETP.GT.U32.AND P0, PT, R2, RZ, PT ; /* 0x000000ff0200720c */
/* 0x000fe20003f04070 */
/*0200*/ IMAD R11, R18, c[0x0][0x18c], R7 ; /* 0x00006300120b7a24 */
/* 0x000fe200078e0207 */
/*0210*/ LEA R5, P1, R8, c[0x0][0x168], 0x1 ; /* 0x00005a0008057a11 */
/* 0x000fe200078208ff */
/*0220*/ IMAD R7, R32, c[0x0][0x18c], R9 ; /* 0x0000630020077a24 */
/* 0x000fe200078e0209 */
/*0230*/ ISETP.GT.AND.EX P0, PT, R3, RZ, PT, P0 ; /* 0x000000ff0300720c */
/* 0x000fe20003f04300 */
/*0240*/ IMAD.MOV.U32 R20, RZ, RZ, RZ ; /* 0x000000ffff147224 */
/* 0x000fe200078e00ff */
/*0250*/ LEA.HI.X R6, R6, c[0x0][0x164], R11, 0x1, P2 ; /* 0x0000590006067a11 */
/* 0x000fe400010f0c0b */
/*0260*/ LEA.HI.X R7, R8, c[0x0][0x16c], R7, 0x1, P1 ; /* 0x00005b0008077a11 */
/* 0x000fd200008f0c07 */
// P0 = K_tiles > 0 as a signed 64-bit value (K_tiles is known non-zero here); otherwise fall back
// to the one-tile loop at 0x9b0.
/*0270*/ @!P0 BRA 0x9b0 ; /* 0x0000073000008947 */
/* 0x000fea0003800000 */
// If K_tiles > 3 fall through into the 4-tile loop at 0x2c0, else jump to the 2-tile stage at 0x6d0.
// P0 := true here, meaning no unrolled iteration has run yet; it is consumed at 0x990.
/*0280*/ ISETP.GT.U32.AND P1, PT, R2, 0x3, PT ; /* 0x000000030200780c */
/* 0x000fe40003f24070 */
/*0290*/ PLOP3.LUT P0, PT, PT, PT, PT, 0x80, 0x0 ; /* 0x000000000000781c */
/* 0x000fe40003f0f070 */
/*02a0*/ ISETP.GT.AND.EX P1, PT, R3, RZ, PT, P1 ; /* 0x000000ff0300720c */
/* 0x000fda0003f24310 */
/*02b0*/ @!P1 BRA 0x6d0 ; /* 0x0000041000009947 */
/* 0x000fea0003800000 */
// 4-tile unrolled path (runs while more than 3 tiles remain). The per-lane offset, in 32-bit words, is
// (lane >> 2) * (K >> 1) + (lane & 3). K >> 1 is the row stride in words (two halves per word), so lane/4
// selects a row and lane%4 a pair of adjacent halves. R19:R30 holds that offset in bytes.
// P0 := false (an unrolled iteration is about to run).
/*02c0*/ S2R R9, SR_LANEID ; /* 0x0000000000097919 */
/* 0x000e220000000000 */
/*02d0*/ SHF.R.U32.HI R10, RZ, 0x1, R0 ; /* 0x00000001ff0a7819 */
/* 0x000fe40000011600 */
/*02e0*/ PLOP3.LUT P0, PT, PT, PT, PT, 0x8, 0x0 ; /* 0x000000000000781c */
/* 0x000fe40003f0e170 */
/*02f0*/ LOP3.LUT R8, R9, 0x3, RZ, 0xc0, !PT ; /* 0x0000000309087812 */
/* 0x001fe400078ec0ff */
/*0300*/ SHF.R.U32.HI R11, RZ, 0x2, R9 ; /* 0x00000002ff0b7819 */
/* 0x000fe20000011609 */
/*0310*/ IMAD.MOV.U32 R9, RZ, RZ, RZ ; /* 0x000000ffff097224 */
/* 0x000fc800078e00ff */
/*0320*/ IMAD.WIDE.U32 R8, R11, R10, R8 ; /* 0x0000000a0b087225 */
/* 0x000fc800078e0008 */
/*0330*/ IMAD.SHL.U32 R30, R8.reuse, 0x4, RZ ; /* 0x00000004081e7824 */
/* 0x040fe200078e00ff */
/*0340*/ SHF.L.U64.HI R19, R8, 0x2, R9 ; /* 0x0000000208137819 */
/* 0x000fc80000010209 */
// Loop head (target of the branch at 0x6c0). R23:R22 = A tile + offset, R39:R38 = B tile + offset.
// R37:R36 and R35:R34 are the same addresses advanced by K * 16 bytes, i.e. 8 rows of A / 8 columns
// of column-major B further.
/*0350*/ IADD3 R22, P1, R4, R30.reuse, RZ ; /* 0x0000001e04167210 */
/* 0x0a0fe40007f3e0ff */
/*0360*/ IADD3 R38, P2, R5, R30, RZ ; /* 0x0000001e05267210 */
/* 0x000fc60007f5e0ff */
/*0370*/ IMAD.X R23, R6, 0x1, R19.reuse, P1 ; /* 0x0000000106177824 */
/* 0x100fe400008e0613 */
/*0380*/ IMAD.X R39, R7, 0x1, R19, P2 ; /* 0x0000000107277824 */
/* 0x000fe400010e0613 */
/*0390*/ IMAD.WIDE.U32 R36, R0.reuse, 0x10, R22 ; /* 0x0000001000247825 */
/* 0x040fe200078e0016 */
// Each LDG.E fetches 32 bits (two halves). Byte offsets +0x00/+0x10 belong to k-tile 0, +0x20/+0x30 to
// tile 1, +0x40/+0x50 to tile 2 and +0x60/+0x70 to tile 3 (16 halves = 0x20 bytes per tile).
// The scheduler interleaves these loads with the HMMAs below.
/*03a0*/ LDG.E R8, [R22.64] ; /* 0x0000000416087981 */
/* 0x000ea6000c1e1900 */
/*03b0*/ IMAD.WIDE.U32 R34, R0, 0x10, R38 ; /* 0x0000001000227825 */
/* 0x000fe200078e0026 */
/*03c0*/ LDG.E R12, [R38.64] ; /* 0x00000004260c7981 */
/* 0x000ea8000c1e1900 */
/*03d0*/ LDG.E R13, [R38.64+0x10] ; /* 0x00001004260d7981 */
/* 0x000ea8000c1e1900 */
/*03e0*/ LDG.E R10, [R22.64+0x10] ; /* 0x00001004160a7981 */
/* 0x000ea8000c1e1900 */
/*03f0*/ LDG.E R9, [R36.64] ; /* 0x0000000424097981 */
/* 0x000ea8000c1e1900 */
/*0400*/ LDG.E R11, [R36.64+0x10] ; /* 0x00001004240b7981 */
/* 0x000ea8000c1e1900 */
/*0410*/ LDG.E R24, [R34.64] ; /* 0x0000000422187981 */
/* 0x000ee8000c1e1900 */
/*0420*/ LDG.E R25, [R34.64+0x10] ; /* 0x0000100422197981 */
/* 0x000ee8000c1e1900 */
/*0430*/ LDG.E R14, [R34.64+0x20] ; /* 0x00002004220e7981 */
/* 0x000f28000c1e1900 */
/*0440*/ LDG.E R15, [R34.64+0x30] ; /* 0x00003004220f7981 */
/* 0x000f28000c1e1900 */
/*0450*/ LDG.E R26, [R34.64+0x40] ; /* 0x00004004221a7981 */
/* 0x000f28000c1e1900 */
/*0460*/ LDG.E R27, [R34.64+0x50] ; /* 0x00005004221b7981 */
/* 0x000f28000c1e1900 */
/*0470*/ LDG.E R28, [R38.64+0x60] ; /* 0x00006004261c7981 */
/* 0x000f28000c1e1900 */
/*0480*/ LDG.E R29, [R38.64+0x70] ; /* 0x00007004261d7981 */
/* 0x000f22000c1e1900 */
// Each wmma::mma_sync becomes two HMMA.16816.F16 instructions, one per 8-column half of the 16x16 output.
// Operands: A = 4 registers (R8..R11 or R12..R15), B = 2 registers, accumulator = 2 registers.
// The two accumulator chains are renamed through R12/R24 inside the body; 8 HMMAs = 4 tiles x 2.
/*0490*/ HMMA.16816.F16 R12, R8.reuse, R12, R16 ; /* 0x0000000c080c723c */
/* 0x044b660000000810 */
/*04a0*/ LDG.E R16, [R38.64+0x20] ; /* 0x0000200426107981 */
/* 0x020ea8000c1e1900 */
/*04b0*/ LDG.E R17, [R38.64+0x30] ; /* 0x0000300426117981 */
/* 0x000ea2000c1e1900 */
/*04c0*/ HMMA.16816.F16 R24, R8, R24, R20 ; /* 0x000000180818723c */
/* 0x008b660000000814 */
/*04d0*/ LDG.E R8, [R22.64+0x20] ; /* 0x0000200416087981 */
/* 0x020ea8000c1e1900 */
/*04e0*/ LDG.E R10, [R22.64+0x30] ; /* 0x00003004160a7981 */
/* 0x000ea8000c1e1900 */
/*04f0*/ LDG.E R9, [R36.64+0x20] ; /* 0x0000200424097981 */
/* 0x000ea8000c1e1900 */
/*0500*/ LDG.E R11, [R36.64+0x30] ; /* 0x00003004240b7981 */
/* 0x000ea8000c1e1900 */
/*0510*/ LDG.E R20, [R38.64+0x40] ; /* 0x0000400426147981 */
/* 0x000ee8000c1e1900 */
/*0520*/ LDG.E R21, [R38.64+0x50] ; /* 0x0000500426157981 */
/* 0x000ee2000c1e1900 */
/*0530*/ HMMA.16816.F16 R16, R8.reuse, R16, R12 ; /* 0x000000100810723c */
/* 0x044b66000000080c */
/*0540*/ LDG.E R12, [R22.64+0x60] ; /* 0x00006004160c7981 */
/* 0x0200a8000c1e1900 */
/*0550*/ LDG.E R13, [R36.64+0x60] ; /* 0x00006004240d7981 */
/* 0x000ea2000c1e1900 */
/*0560*/ HMMA.16816.F16 R24, R8, R14, R24 ; /* 0x0000000e0818723c */
/* 0x010b660000000818 */
/*0570*/ LDG.E R8, [R22.64+0x40] ; /* 0x0000400416087981 */
/* 0x0200e8000c1e1900 */
/*0580*/ LDG.E R10, [R22.64+0x50] ; /* 0x00005004160a7981 */
/* 0x0000e8000c1e1900 */
/*0590*/ LDG.E R9, [R36.64+0x40] ; /* 0x0000400424097981 */
/* 0x000ee8000c1e1900 */
/*05a0*/ LDG.E R11, [R36.64+0x50] ; /* 0x00005004240b7981 */
/* 0x000ee8000c1e1900 */
/*05b0*/ LDG.E R14, [R22.64+0x70] ; /* 0x00007004160e7981 */
/* 0x0000a8000c1e1900 */
/*05c0*/ LDG.E R15, [R36.64+0x70] ; /* 0x00007004240f7981 */
/* 0x000ea8000c1e1900 */
/*05d0*/ LDG.E R22, [R34.64+0x60] ; /* 0x0000600422167981 */
/* 0x001f28000c1e1900 */
/*05e0*/ LDG.E R23, [R34.64+0x70] ; /* 0x0000700422177981 */
/* 0x000f22000c1e1900 */
// Bookkeeping: K_tiles -= 4, loop again while it is > 3; advance the A and B pointers by 0x80 bytes
// (4 tiles x 16 halves).
/*05f0*/ IADD3 R2, P1, R2, -0x4, RZ ; /* 0xfffffffc02027810 */
/* 0x000fc80007f3e0ff */
/*0600*/ IADD3.X R3, R3, -0x1, RZ, P1, !PT ; /* 0xffffffff03037810 */
/* 0x000fe40000ffe4ff */
/*0610*/ ISETP.GT.U32.AND P1, PT, R2, 0x3, PT ; /* 0x000000030200780c */
/* 0x000fc80003f24070 */
/*0620*/ ISETP.GT.AND.EX P1, PT, R3, RZ, PT, P1 ; /* 0x000000ff0300720c */
/* 0x000fe40003f24310 */
/*0630*/ IADD3 R5, P2, R5, 0x80, RZ ; /* 0x0000008005057810 */
/* 0x000fe40007f5e0ff */
/*0640*/ IADD3 R4, P3, R4, 0x80, RZ ; /* 0x0000008004047810 */
/* 0x000fc60007f7e0ff */
/*0650*/ IMAD.X R7, RZ, RZ, R7, P2 ; /* 0x000000ffff077224 */
/* 0x000fe400010e0607 */
/*0660*/ IMAD.X R6, RZ, RZ, R6, P3 ; /* 0x000000ffff067224 */
/* 0x000fe200018e0606 */
/*0670*/ HMMA.16816.F16 R16, R8.reuse, R20, R16 ; /* 0x000000140810723c */
/* 0x048f700000000810 */
/*0680*/ HMMA.16816.F16 R24, R8, R26, R24 ; /* 0x0000001a0818723c */
/* 0x000f5e0000000818 */
/*0690*/ NOP ; /* 0x0000000000007918 */
/* 0x000fc20000000000 */
/*06a0*/ HMMA.16816.F16 R16, R12.reuse, R28, R16 ; /* 0x0000001c0c10723c */
/* 0x064b700000000810 */
/*06b0*/ HMMA.16816.F16 R20, R12, R22, R24 ; /* 0x000000160c14723c */
/* 0x010b620000000818 */
/*06c0*/ @P1 BRA 0x350 ; /* 0xfffffc8000001947 */
/* 0x000fce000383ffff */
// 2-tile stage (0x6d0, runs at most once). If fewer than 2 tiles remain, jump to the tail check at 0x980.
/*06d0*/ ISETP.GT.U32.AND P1, PT, R2, 0x1, PT ; /* 0x000000010200780c */
/* 0x000fc80003f24070 */
/*06e0*/ ISETP.GT.AND.EX P1, PT, R3, RZ, PT, P1 ; /* 0x000000ff0300720c */
/* 0x000fda0003f24310 */
/*06f0*/ @!P1 BRA 0x980 ; /* 0x0000028000009947 */
/* 0x000fea0003800000 */
// Same per-lane offset as in the 4-tile path, recomputed. R37:R36 / R39:R38 = A / B tile + offset * 4 bytes;
// R35:R34 / R31:R30 are 8 rows (A) / 8 columns (B) further.
/*0700*/ S2R R9, SR_LANEID ; /* 0x0000000000097919 */
/* 0x000e220000000000 */
/*0710*/ SHF.R.U32.HI R10, RZ, 0x1, R0 ; /* 0x00000001ff0a7819 */
/* 0x000fe40000011600 */
/*0720*/ LOP3.LUT R8, R9, 0x3, RZ, 0xc0, !PT ; /* 0x0000000309087812 */
/* 0x001fe400078ec0ff */
/*0730*/ SHF.R.U32.HI R11, RZ, 0x2, R9 ; /* 0x00000002ff0b7819 */
/* 0x000fe20000011609 */
/*0740*/ IMAD.MOV.U32 R9, RZ, RZ, RZ ; /* 0x000000ffff097224 */
/* 0x000fc800078e00ff */
/*0750*/ IMAD.WIDE.U32 R8, R11, R10, R8 ; /* 0x0000000a0b087225 */
/* 0x000fca00078e0008 */
/*0760*/ LEA R36, P0, R8.reuse, R4, 0x2 ; /* 0x0000000408247211 */
/* 0x040fe400078010ff */
/*0770*/ LEA R38, P1, R8.reuse, R5, 0x2 ; /* 0x0000000508267211 */
/* 0x040fe400078210ff */
/*0780*/ LEA.HI.X R37, R8.reuse, R6, R9.reuse, 0x2, P0 ; /* 0x0000000608257211 */
/* 0x140fe400000f1409 */
/*0790*/ LEA.HI.X R39, R8, R7, R9, 0x2, P1 ; /* 0x0000000708277211 */
/* 0x000fc600008f1409 */
/*07a0*/ IMAD.WIDE.U32 R34, R0.reuse, 0x10, R36 ; /* 0x0000001000227825 */
/* 0x040fe200078e0024 */
// Load two k-tiles (+0x00..+0x30), issue 2 x 2 HMMAs, advance the pointers by 0x40 bytes,
// K_tiles -= 2, and set P0 := false.
/*07b0*/ LDG.E R8, [R36.64] ; /* 0x0000000424087981 */
/* 0x000ea6000c1e1900 */
/*07c0*/ IMAD.WIDE.U32 R30, R0, 0x10, R38 ; /* 0x00000010001e7825 */
/* 0x000fe200078e0026 */
/*07d0*/ LDG.E R22, [R38.64] ; /* 0x0000000426167981 */
/* 0x020ea8000c1e1900 */
/*07e0*/ LDG.E R23, [R38.64+0x10] ; /* 0x0000100426177981 */
/* 0x000ea8000c1e1900 */
/*07f0*/ LDG.E R10, [R36.64+0x10] ; /* 0x00001004240a7981 */
/* 0x000ea8000c1e1900 */
/*0800*/ LDG.E R24, [R30.64] ; /* 0x000000041e187981 */
/* 0x000ee8000c1e1900 */
/*0810*/ LDG.E R9, [R34.64] ; /* 0x0000000422097981 */
/* 0x000ea8000c1e1900 */
/*0820*/ LDG.E R11, [R34.64+0x10] ; /* 0x00001004220b7981 */
/* 0x000ea8000c1e1900 */
/*0830*/ LDG.E R25, [R30.64+0x10] ; /* 0x000010041e197981 */
/* 0x000ee8000c1e1900 */
/*0840*/ LDG.E R26, [R38.64+0x20] ; /* 0x00002004261a7981 */
/* 0x000f28000c1e1900 */
/*0850*/ LDG.E R27, [R38.64+0x30] ; /* 0x00003004261b7981 */
/* 0x000f28000c1e1900 */
/*0860*/ LDG.E R12, [R36.64+0x20] ; /* 0x00002004240c7981 */
/* 0x000f28000c1e1900 */
/*0870*/ LDG.E R14, [R36.64+0x30] ; /* 0x00003004240e7981 */
/* 0x000f28000c1e1900 */
/*0880*/ LDG.E R13, [R34.64+0x20] ; /* 0x00002004220d7981 */
/* 0x000f28000c1e1900 */
/*0890*/ LDG.E R15, [R34.64+0x30] ; /* 0x00003004220f7981 */
/* 0x000f28000c1e1900 */
/*08a0*/ LDG.E R28, [R30.64+0x20] ; /* 0x000020041e1c7981 */
/* 0x000f28000c1e1900 */
/*08b0*/ LDG.E R29, [R30.64+0x30] ; /* 0x000030041e1d7981 */
/* 0x000f22000c1e1900 */
/*08c0*/ IADD3 R4, P2, R4, 0x40, RZ ; /* 0x0000004004047810 */
/* 0x000fc40007f5e0ff */
/*08d0*/ IADD3 R5, P1, R5, 0x40, RZ ; /* 0x0000004005057810 */
/* 0x000fe40007f3e0ff */
/*08e0*/ IADD3 R2, P3, R2, -0x2, RZ ; /* 0xfffffffe02027810 */
/* 0x000fe20007f7e0ff */
/*08f0*/ IMAD.X R6, RZ, RZ, R6, P2 ; /* 0x000000ffff067224 */
/* 0x000fe200010e0606 */
/*0900*/ PLOP3.LUT P0, PT, PT, PT, PT, 0x8, 0x0 ; /* 0x000000000000781c */
/* 0x000fe20003f0e170 */
/*0910*/ IMAD.X R7, RZ, RZ, R7, P1 ; /* 0x000000ffff077224 */
/* 0x000fe200008e0607 */
/*0920*/ IADD3.X R3, R3, -0x1, RZ, P3, !PT ; /* 0xffffffff03037810 */
/* 0x000fe20001ffe4ff */
/*0930*/ HMMA.16816.F16 R16, R8.reuse, R22, R16 ; /* 0x000000160810723c */
/* 0x044f700000000810 */
/*0940*/ HMMA.16816.F16 R20, R8, R24, R20 ; /* 0x000000180814723c */
/* 0x008f5e0000000814 */
/*0950*/ NOP ; /* 0x0000000000007918 */
/* 0x000fc20000000000 */
/*0960*/ HMMA.16816.F16 R16, R12.reuse, R26, R16 ; /* 0x0000001a0c10723c */
/* 0x070b700000000810 */
/*0970*/ HMMA.16816.F16 R20, R12, R28, R20 ; /* 0x0000001c0c14723c */
/* 0x000b500000000814 */
// Tail check (0x980): enter the one-tile loop at 0x9b0 if tiles remain or P0 is still true
// (no unrolled iteration ran); otherwise go to the store at 0xbc0.
/*0980*/ ISETP.NE.U32.AND P1, PT, R2, RZ, PT ; /* 0x000000ff0200720c */
/* 0x000fc80003f25070 */
/*0990*/ ISETP.NE.OR.EX P0, PT, R3, RZ, P0, P1 ; /* 0x000000ff0300720c */
/* 0x000fda0000705710 */
/*09a0*/ @!P0 BRA 0xbc0 ; /* 0x0000021000008947 */
/* 0x000fea0003800000 */
// One-tile loop (0x9b0): recompute the per-lane byte offset R19:R30, then iterate from 0xa30.
/*09b0*/ S2R R9, SR_LANEID ; /* 0x0000000000097919 */
/* 0x000e220000000000 */
/*09c0*/ SHF.R.U32.HI R10, RZ, 0x1, R0 ; /* 0x00000001ff0a7819 */
/* 0x000fe40000011600 */
/*09d0*/ LOP3.LUT R8, R9, 0x3, RZ, 0xc0, !PT ; /* 0x0000000309087812 */
/* 0x001fe400078ec0ff */
/*09e0*/ SHF.R.U32.HI R11, RZ, 0x2, R9 ; /* 0x00000002ff0b7819 */
/* 0x000fe20000011609 */
/*09f0*/ IMAD.MOV.U32 R9, RZ, RZ, RZ ; /* 0x000000ffff097224 */
/* 0x000fc800078e00ff */
/*0a00*/ IMAD.WIDE.U32 R8, R11, R10, R8 ; /* 0x0000000a0b087225 */
/* 0x000fc800078e0008 */
/*0a10*/ IMAD.SHL.U32 R30, R8.reuse, 0x4, RZ ; /* 0x00000004081e7824 */
/* 0x040fe200078e00ff */
/*0a20*/ SHF.L.U64.HI R19, R8, 0x2, R9 ; /* 0x0000000208137819 */
/* 0x000fc80000010209 */
// Loop head (target of the branch at 0xbb0). Each iteration does 8 x LDG.E for one k-tile (A in R8..R11,
// the B halves in R12/R13 and R28/R29), then K_tiles -= 1 and both pointers += 0x20 bytes.
// One mma_sync = two HMMA.16816.F16, accumulating into R16/R17 and R20/R21.
/*0a30*/ IADD3 R24, P0, R4, R30.reuse, RZ ; /* 0x0000001e04187210 */
/* 0x0a0fe40007f1e0ff */
/*0a40*/ IADD3 R26, P1, R5, R30, RZ ; /* 0x0000001e051a7210 */
/* 0x000fc60007f3e0ff */
/*0a50*/ IMAD.X R25, R6, 0x1, R19.reuse, P0 ; /* 0x0000000106197824 */
/* 0x100fe400000e0613 */
/*0a60*/ IMAD.X R27, R7, 0x1, R19, P1 ; /* 0x00000001071b7824 */
/* 0x000fe400008e0613 */
/*0a70*/ IMAD.WIDE.U32 R14, R0.reuse, 0x10, R24 ; /* 0x00000010000e7825 */
/* 0x040fe200078e0018 */
/*0a80*/ LDG.E R8, [R24.64] ; /* 0x0000000418087981 */
/* 0x000ea6000c1e1900 */
/*0a90*/ IMAD.WIDE.U32 R22, R0, 0x10, R26 ; /* 0x0000001000167825 */
/* 0x000fe200078e001a */
/*0aa0*/ LDG.E R12, [R26.64] ; /* 0x000000041a0c7981 */
/* 0x000ea8000c1e1900 */
/*0ab0*/ LDG.E R13, [R26.64+0x10] ; /* 0x000010041a0d7981 */
/* 0x000ea8000c1e1900 */
/*0ac0*/ LDG.E R10, [R24.64+0x10] ; /* 0x00001004180a7981 */
/* 0x000ea8000c1e1900 */
/*0ad0*/ LDG.E R9, [R14.64] ; /* 0x000000040e097981 */
/* 0x000ea8000c1e1900 */
/*0ae0*/ LDG.E R11, [R14.64+0x10] ; /* 0x000010040e0b7981 */
/* 0x000ea8000c1e1900 */
/*0af0*/ LDG.E R28, [R22.64] ; /* 0x00000004161c7981 */
/* 0x000ee8000c1e1900 */
/*0b00*/ LDG.E R29, [R22.64+0x10] ; /* 0x00001004161d7981 */
/* 0x000ee2000c1e1900 */
/*0b10*/ IADD3 R2, P0, R2, -0x1, RZ ; /* 0xffffffff02027810 */
/* 0x000fc80007f1e0ff */
/*0b20*/ IADD3.X R3, R3, -0x1, RZ, P0, !PT ; /* 0xffffffff03037810 */
/* 0x000fe400007fe4ff */
/*0b30*/ ISETP.NE.U32.AND P0, PT, R2, RZ, PT ; /* 0x000000ff0200720c */
/* 0x000fc80003f05070 */
/*0b40*/ ISETP.NE.AND.EX P0, PT, R3, RZ, PT, P0 ; /* 0x000000ff0300720c */
/* 0x000fe40003f05300 */
/*0b50*/ IADD3 R5, P1, R5, 0x20, RZ ; /* 0x0000002005057810 */
/* 0x000fe40007f3e0ff */
/*0b60*/ IADD3 R4, P2, R4, 0x20, RZ ; /* 0x0000002004047810 */
/* 0x000fc60007f5e0ff */
/*0b70*/ IMAD.X R7, RZ, RZ, R7, P1 ; /* 0x000000ffff077224 */
/* 0x000fe400008e0607 */
/*0b80*/ IMAD.X R6, RZ, RZ, R6, P2 ; /* 0x000000ffff067224 */
/* 0x000fe200010e0606 */
/*0b90*/ HMMA.16816.F16 R16, R8.reuse, R12, R16 ; /* 0x0000000c0810723c */
/* 0x044b700000000810 */
/*0ba0*/ HMMA.16816.F16 R20, R8, R28, R20 ; /* 0x0000001c0814723c */
/* 0x008b620000000814 */
/*0bb0*/ @P0 BRA 0xa30 ; /* 0xfffffe7000000947 */
/* 0x000fce000383ffff */
// store_matrix_sync (0xbc0). R7 = N >> 1 (row stride of C in 32-bit words);
// R33:R32 = warp_row * N + warp_col; R33:R5 = C + that * 2 bytes.
/*0bc0*/ S2R R2, SR_LANEID ; /* 0x0000000000027919 */
/* 0x000e220000000000 */
/*0bd0*/ IMAD.MOV.U32 R7, RZ, RZ, c[0x0][0x180] ; /* 0x00006000ff077624 */
/* 0x000fc400078e00ff */
/*0be0*/ IMAD.MOV.U32 R33, RZ, RZ, RZ ; /* 0x000000ffff217224 */
/* 0x000fe400078e00ff */
/*0bf0*/ IMAD.MOV.U32 R3, RZ, RZ, RZ ; /* 0x000000ffff037224 */
/* 0x000fe200078e00ff */
/*0c00*/ SHF.R.U32.HI R7, RZ, 0x1, R7 ; /* 0x00000001ff077819 */
/* 0x000fe20000011607 */
/*0c10*/ IMAD.WIDE.U32 R32, R18, c[0x0][0x180], R32 ; /* 0x0000600012207a25 */
/* 0x000fc800078e0020 */
/*0c20*/ IMAD R33, R18, c[0x0][0x184], R33 ; /* 0x0000610012217a24 */
/* 0x000fe200078e0221 */
/*0c30*/ LEA R5, P0, R32, c[0x0][0x170], 0x1 ; /* 0x00005c0020057a11 */
/* 0x000fc800078008ff */
/*0c40*/ LEA.HI.X R33, R32, c[0x0][0x174], R33, 0x1, P0 ; /* 0x00005d0020217a11 */
/* 0x000fe400000f0c21 */
// Per-lane offset (lane >> 2) * (N >> 1) + (lane & 3) words gives R5:R4. R3:R2 = R5:R4 + (N >> 1) * 0x20 bytes,
// i.e. 8 rows further down.
/*0c50*/ SHF.R.U32.HI R0, RZ, 0x2, R2 ; /* 0x00000002ff007819 */
/* 0x001fe40000011602 */
/*0c60*/ LOP3.LUT R2, R2, 0x3, RZ, 0xc0, !PT ; /* 0x0000000302027812 */
/* 0x000fca00078ec0ff */
/*0c70*/ IMAD.WIDE.U32 R2, R7, R0, R2 ; /* 0x0000000007027225 */
/* 0x000fca00078e0002 */
/*0c80*/ LEA R4, P0, R2, R5, 0x2 ; /* 0x0000000502047211 */
/* 0x000fc800078010ff */
/*0c90*/ LEA.HI.X R5, R2, R33, R3, 0x2, P0 ; /* 0x0000002102057211 */
/* 0x000fca00000f1403 */
/*0ca0*/ IMAD.WIDE.U32 R2, R7, 0x20, R4 ; /* 0x0000002007027825 */
/* 0x000fe200078e0004 */
// Four 32-bit stores (two halves each). R16 and R17 go to rows r and r + 8 of the left 8 columns.
// R20 and R21 go to the same rows, 0x10 bytes (8 columns) to the right.
/*0cb0*/ STG.E [R4.64], R16 ; /* 0x0000001004007986 */
/* 0x020fe8000c101904 */
/*0cc0*/ STG.E [R2.64], R17 ; /* 0x0000001102007986 */
/* 0x000fe8000c101904 */
/*0cd0*/ STG.E [R4.64+0x10], R20 ; /* 0x0000101404007986 */
/* 0x000fe8000c101904 */
/*0ce0*/ STG.E [R2.64+0x10], R21 ; /* 0x0000101502007986 */
/* 0x000fe2000c101904 */
/*0cf0*/ EXIT ; /* 0x000000000000794d */
/* 0x000fea0003800000 */
// Never executed: a branch-to-self placed after EXIT, followed by NOP padding; the listing is truncated here.
/*0d00*/ BRA 0xd00; /* 0xfffffff000007947 */
/* 0x000fc0000383ffff */
/*0d10*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0d20*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0d30*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0d40*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0d50*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0d60*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0d70*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0d80*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0d90*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0da0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0db0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0dc0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0dd0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0de0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0df0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
..........
我们依然主要关注WMMA相关的SASS指令,如下所示。可以发现,一条m16n16k16的wmma.mma.sync在底层由两条HMMA.16816指令实现(每条计算16 × 16输出tile中16 × 8的一半)。同样地,直接编写SASS指令也可以看作调用Tensor Core的另一种方法,不过Nvidia并未提供官方的SASS汇编器。
HMMA.16816.F16
在Nvidia Tensor Core初探中提到,Nvidia提供了四种调用Tensor Core的编程方法,这里提到了三种,还有一种是MMA PTX指令,其中m16n8k16形状的mma.sync PTX指令在底层即对应HMMA.16816指令,后续会在MMA PTX相关文章中提及。
4 其他
4.1 HGEMM优化
学习WMMA API的目标在于调用Tensor Core优化HGEMM,那么相比于cuBLAS,WMMA的性能究竟如何?可以参考开源代码cuda_hgemm。