对于4x4矩阵。计算一次 FMA(乘累加)为一次运算,而各读取 A 和B中一个元素为1+1=2次运算。访存比为1/2。
而若一个 thread 并不只计算一个结果,而是计算 4x4=16个结果,就要从A和B中分别取出4个数据,共8个数据。访存比变为16/8=2,是上面的4倍。
上面是使用一个block来计算一个完整矩阵的情况,对于更大的矩阵,需要用多个block:
A的每个block的大小为TILE_M*TILE_K,B的每个block的大小为TILE_K*TILE_N:
注意,这里将全局内存中的A矩阵存入共享内存smemA中时进行了矩阵转置。
然后,需要从共享内存中取出A和B矩阵用于计算,每个线程分别从A和B中取出4*1的矩阵来进行计算,得到C矩阵:
C矩阵为128*128大小的矩阵。C矩阵被分成了四份,每份的尺寸都为4*4,使用同一个线程计算这四份4*4大小区域的FMA计算。
最后,利用 Prefetch 的思想,隐藏 Global Memory 读入中间寄存器、将来自 Global Memory 的数据块写入 Shared Memory、从 Shared Memory 中读出数据块的访存延迟,以免计算单元因为 stall 而空闲太久,最终的伪代码如下所示:
#define TILE_K 16
__shared__ float4 smemA[2][TILE_K * 128 / 4];
__shared__ float4 smemB[2][TILE_K * 128 / 4];
float4 c[8][2] = {{make_float4(0.f, 0.f, 0.f, 0.f)}};
float4 ldg_a_reg[2];
float4 ldg_b_reg[2];
float4 a_reg[2][2];
float4 b_reg[2][2];
// transfer first tile from global mem to shared mem
load_gmem_tile_to_reg(A, 0, ldg_a_reg);
load_gmem_tile_to_reg(B, 0, ldg_b_reg);
store_reg_to_smem_tile_transpose(ldg_a_reg, 0, smemA[0]);
store_reg_to_smem_tile(ldg_b_reg, 0, smemB[0]);
__syncthreads();
// load first tile from shared mem to register
load_smem_tile_to_reg(smemA[0], 0, a_reg[0]);
load_smem_tile_to_reg(smemB[0], 0, b_reg[0]);
int write_stage_idx = 1; //ping pong switch
do {
i += TILE_K;
// load next tile from global mem
load_gmem_tile_to_reg(A, i, ldg_a_reg);
load_gmem_tile_to_reg(B, i, ldg_b_reg);
int load_stage_idx = write_stage_idx ^ 1;
#pragma unroll
for(int j = 0; j < TILE_K - 1; ++j) {
// load next tile from shared mem to register
load_smem_tile_to_reg(smemA[load_stage_idx], j + 1, a_reg[(j + 1) % 2]);
load_smem_tile_to_reg(smemB[load_stage_idx], j + 1, b_reg[(j + 1) % 2]);
// compute matrix multiply accumulate 8x8
mma8x8(a_reg[j % 2], b_reg[j % 2], c);
}
if(i < K) {
// store next tile to shared mem
store_reg_to_smem_tile_transpose(ldg_a_reg, 0, smemA[write_stage_idx]);
store_reg_to_smem_tile(ldg_b_reg, 0, smemB[write_stage_idx]);
// use double buffer, only need one sync
__syncthreads();
// switch
write_stage_idx ^= 1;
}
// load first tile from shared mem to register of next iter
load_smem_tile_to_reg(smemA[load_stage_idx ^ 1], 0, a_reg[0]);
load_smem_tile_to_reg(smemB[load_stage_idx ^ 1], 0, b_reg[0]);
// compute last tile mma 8x8
mma8x8(a_reg[1], b_reg[1], c);
} while (i < K);
store_c(c, C);