#define A(i,j) A[(i) + (j)*lda]
#define B(i,j) B[(i) + (j)*ldb]
#define C(i,j) C[(i) + (j)*ldc]
#define sa(i,j) sa[((i)<<5) + (j)]
#define sb(i,j) sb[((i)<<5) + (j)]
#define MS 32
#define NS 32
#define KS 32
// Cache-blocking version, without register-level data re-use.
// Uses a 1-D thread block and derives (row, col) from tx, which saves one live register (ty).
__global__ __launch_bounds__(1024)
void mysgemm_v3(int M, int N, int K, float alpha, float* A, float* B, float beta, float* C){
int lda = M, ldb = K, ldc = M;
int tx = threadIdx.x;
int bx = blockIdx.x, by = blockIdx.y;
int row = tx&31, col = tx>>5;      // thread's position inside the 32x32 C tile
A = &A((bx<<5),0);                 // advance A to this block's row panel
B = &B(0,(by<<5));                 // advance B to this block's column panel
C = &C((bx<<5),(by<<5));           // advance C to this block's 32x32 output tile
__shared__ float sa[MS*KS];
__shared__ float sb[KS*NS];
float tmp=0.;
for (int k_count = 0; k_count<K; k_count+=KS){
sa(row,col)=A(row,col);            // stage a 32x32 tile of A
sb(col,row)=B(row,col);            // stage a 32x32 tile of B, stored transposed
A+=(lda<<5);B+=32;                 // advance A and B to the next 32-wide K step
__syncthreads();
for (int inner_k_count=0;inner_k_count<KS;inner_k_count++){
tmp += sa(row,inner_k_count) * sb(col,inner_k_count);
}
__syncthreads();
}
C(row,col) = alpha * tmp + beta*C(row,col);
}
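// A minimal usage sketch (not part of the original source): mysgemm_v3 assumes
// column-major A (MxK), B (KxN), C (MxN) with M, N, K multiples of 32. Each
// 1024-thread block computes one 32x32 tile of C, so the grid is (M/32, N/32).
// The host-side wrapper name below is hypothetical.
static void launch_mysgemm_v3(int M, int N, int K, float alpha,
                              float* d_A, float* d_B, float beta, float* d_C) {
    dim3 block(1024);             // 32x32 threads, one per element of the C tile
    dim3 grid(M >> 5, N >> 5);    // blockIdx.x walks the M dimension, blockIdx.y walks N
    mysgemm_v3<<<grid, block>>>(M, N, K, alpha, d_A, d_B, beta, d_C);
}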
// optimize sgemm
#include <stdio.h>
#include <stdlib.h>
#include "assert.h"
// CUDA runtime
#include <cuda_runtime.h>
#include <cublas_v2.h>
// Compute the flat offset of (row, col) in a row-major matrix whose leading dimension (row width) is ld.
#define OFFSET(row, col, ld) ((row) * (ld) + (col))
// Reinterpret a float lvalue as a float4 so 4 consecutive floats move in one 128-bit access (the address must be 16-byte aligned).
#define FETCH_FLOAT4(pointer) (reinterpret_cast<float4*>(&(pointer))[0])
#define checkCudaErrors(func) \
{ \
cudaError_t e = (func); \
if(e != cudaSuccess) \
printf ("%s %d CUDA: %s\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
}
// A is row-major M x K (leading dimension K); B is row-major K x N (leading dimension N).
template <
const int BLOCK_SIZE_M, // height of the block of C that each thread block computes
const int BLOCK_SIZE_K, // width of the block of A that each thread block stages in shared memory
const int BLOCK_SIZE_N, // width of the block of C that each thread block computes
const int THREAD_SIZE_Y, // height of the block of C that each thread computes
const int THREAD_SIZE_X, // width of the block of C that each thread computes
const bool ENABLE_DOUBLE_BUFFER // whether to enable double buffering
>
__global__ void Sgemm(
float* __restrict__ A,
float* __restrict__ B,
float* __restrict__ C,
const int M,
const int N,
const int K) {
// Block index
int bx = blockIdx.x;
int by = blockIdx.y;
// Thread index
int tx = threadIdx.x;
int ty = threadIdx.y;
// number of threads per block along x and y
const int THREAD_X_PER_BLOCK = BLOCK_SIZE_N / THREAD_SIZE_X;
const int THREAD_Y_PER_BLOCK = BLOCK_SIZE_M / THREAD_SIZE_Y;
const int THREAD_NUM_PER_BLOCK = THREAD_X_PER_BLOCK * THREAD_Y_PER_BLOCK;
// linear thread id within the block
const int tid = ty * THREAD_X_PER_BLOCK + tx;
// shared memory
__shared__ float As[2][BLOCK_SIZE_K][BLOCK_SIZE_M];
__shared__ float Bs[2][BLOCK_SIZE_K][BLOCK_SIZE_N];
// registers for C
float accum[THREAD_SIZE_Y][THREAD_SIZE_X];
#pragma unroll
for (int i = 0; i < THREAD_SIZE_Y; i++) {
#pragma unroll
for (int j = 0; j < THREAD_SIZE_X; j++) {
accum[i][j] = 0.0;
}
}
// registers for A and B
float frag_a[2][THREAD_SIZE_Y];
float frag_b[2][THREAD_SIZE_X];
// registers used to stage loads from global memory
const int ldg_num_a = BLOCK_SIZE_M * BLOCK_SIZE_K / (THREAD_NUM_PER_BLOCK * 4);
const int ldg_num_b = BLOCK_SIZE_K * BLOCK_SIZE_N / (THREAD_NUM_PER_BLOCK * 4);
float ldg_a_reg[4 * ldg_num_a];
float ldg_b_reg[4 * ldg_num_b];
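// With the default 128x8x128 tiling and 256 threads per block, ldg_num_a and
// ldg_num_b are both 1, i.e. each thread stages exactly one float4 of A and one
// float4 of B per K tile.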
// number of threads needed to load one row of a tile (each thread loads a float4)
const int A_TILE_THREAD_PER_ROW = BLOCK_SIZE_K / 4;
const int B_TILE_THREAD_PER_ROW = BLOCK_SIZE_N / 4;
// first row and column of the tile that this thread loads
const int A_TILE_ROW_START = tid / A_TILE_THREAD_PER_ROW;
const int B_TILE_ROW_START = tid / B_TILE_THREAD_PER_ROW;
const int A_TILE_COL = tid % A_TILE_THREAD_PER_ROW * 4;
const int B_TILE_COL = tid % B_TILE_THREAD_PER_ROW * 4;
// row stride a thread advances by when loading multiple rows of a tile
const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / A_TILE_THREAD_PER_ROW;
const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / B_TILE_THREAD_PER_ROW;
A = &A[(BLOCK_SIZE_M * by) * K];
B = &B[BLOCK_SIZE_N * bx];
// per-warp and per-lane indices used to read this thread's fragments from shared memory
const int warp_id = tid / 32;
const int lane_id = tid % 32;
const int a_tile_index = warp_id / 2 * 16 + lane_id / 8 * 4; //warp_id * 8 + (lane_id / 16)*4; // (warp_id/4)*32 + ((lane_id%16)/2)*4;
const int b_tile_index = warp_id % 2 * 32 + lane_id % 8 * 4; //(lane_id % 16) * 4; // (warp_id%4)*16 + (lane_id/16)*8 + (lane_id%2)*4;
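// Warps are laid out 4x2 over the 128x128 C tile: a_tile_index picks this
// thread's 4-row strip of the A fragment and b_tile_index its 4-column strip of
// the B fragment. Combined with the +64 offsets used below, each thread
// accumulates four 4x4 sub-blocks of its 8x8 output tile.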
//transfer first tile from global mem to shared mem
// load A from global memory to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(ldg_a_reg[ldg_index]) = FETCH_FLOAT4(A[OFFSET(
A_TILE_ROW_START + i, // row
A_TILE_COL, // col
K)]);
As[0][A_TILE_COL][A_TILE_ROW_START + i] = ldg_a_reg[ldg_index];
As[0][A_TILE_COL + 1][A_TILE_ROW_START + i] = ldg_a_reg[ldg_index + 1];
As[0][A_TILE_COL + 2][A_TILE_ROW_START + i] = ldg_a_reg[ldg_index + 2];
As[0][A_TILE_COL + 3][A_TILE_ROW_START + i] = ldg_a_reg[ldg_index + 3];
}
// load B from global memory to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
FETCH_FLOAT4(Bs[0][B_TILE_ROW_START + i][B_TILE_COL]) = FETCH_FLOAT4(B[OFFSET(
B_TILE_ROW_START + i, // row
B_TILE_COL, // col
N)]);
}
__syncthreads();
// load A from shared memory to register
FETCH_FLOAT4(frag_a[0][0]) = FETCH_FLOAT4(As[0][0][a_tile_index]);
FETCH_FLOAT4(frag_a[0][4]) = FETCH_FLOAT4(As[0][0][a_tile_index + 64]);
// load B from shared memory to register
FETCH_FLOAT4(frag_b[0][0]) = FETCH_FLOAT4(Bs[0][0][b_tile_index]);
FETCH_FLOAT4(frag_b[0][4]) = FETCH_FLOAT4(Bs[0][0][b_tile_index + 64]);
int write_stage_idx = 1;
int tile_idx = 0;
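// Double-buffering: while the stage indexed by write_stage_idx is being filled
// from global memory, the other stage is consumed by the inner-product loop, so
// only one __syncthreads() per K tile is required.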
do {
// next tile index
tile_idx += BLOCK_SIZE_K;
// load next tile from global mem
if (tile_idx < K) {
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(ldg_a_reg[ldg_index]) = FETCH_FLOAT4(A[OFFSET(
A_TILE_ROW_START + i, // row
A_TILE_COL + tile_idx, // col
K)]);
}
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
int ldg_index = i / B_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(ldg_b_reg[ldg_index]) = FETCH_FLOAT4(B[OFFSET(
tile_idx + B_TILE_ROW_START + i, // row
B_TILE_COL, // col
N)]);
}
}
int load_stage_idx = write_stage_idx ^ 1;
#pragma unroll
for (int j = 0; j < BLOCK_SIZE_K - 1; ++j) {
// load next tile from shared mem to register
// load A from shared memory to register
FETCH_FLOAT4(frag_a[(j + 1) % 2][0]) = FETCH_FLOAT4(As[load_stage_idx][(j + 1)][a_tile_index]);
FETCH_FLOAT4(frag_a[(j + 1) % 2][4]) = FETCH_FLOAT4(As[load_stage_idx][(j + 1)][a_tile_index + 64]);
// load B from shared memory to register
FETCH_FLOAT4(frag_b[(j + 1) % 2][0]) = FETCH_FLOAT4(Bs[load_stage_idx][(j + 1)][b_tile_index]);
FETCH_FLOAT4(frag_b[(j + 1) % 2][4]) = FETCH_FLOAT4(Bs[load_stage_idx][(j + 1)][b_tile_index + 64]);
// compute C THREAD_SIZE_X x THREAD_SIZE_Y
#pragma unroll
for (int thread_y = 0; thread_y < THREAD_SIZE_Y; ++thread_y) {
#pragma unroll
for (int thread_x = 0; thread_x < THREAD_SIZE_X; ++thread_x) {
accum[thread_y][thread_x] += frag_a[j % 2][thread_y] * frag_b[j % 2][thread_x];
}
}
}
if (tile_idx < K) {
// load A from global memory to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
As[write_stage_idx][A_TILE_COL][A_TILE_ROW_START + i] = ldg_a_reg[ldg_index];
As[write_stage_idx][A_TILE_COL + 1][A_TILE_ROW_START + i] = ldg_a_reg[ldg_index + 1];
As[write_stage_idx][A_TILE_COL + 2][A_TILE_ROW_START + i] = ldg_a_reg[ldg_index + 2];
As[write_stage_idx][A_TILE_COL + 3][A_TILE_ROW_START + i] = ldg_a_reg[ldg_index + 3];
}
// load B from global memory to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
int ldg_index = i / B_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(Bs[write_stage_idx][B_TILE_ROW_START + i][B_TILE_COL]) = FETCH_FLOAT4(ldg_b_reg[ldg_index]);
}
// use double buffer, only need one sync
__syncthreads();
// switch
write_stage_idx ^= 1;
}
// load first tile from shared mem to register of next iter
// load A from shared memory to register
FETCH_FLOAT4(frag_a[0][0]) = FETCH_FLOAT4(As[load_stage_idx ^ 1][0][a_tile_index]);
FETCH_FLOAT4(frag_a[0][4]) = FETCH_FLOAT4(As[load_stage_idx ^ 1][0][a_tile_index + 64]);
// load B from shared memory to register
FETCH_FLOAT4(frag_b[0][0]) = FETCH_FLOAT4(Bs[load_stage_idx ^ 1][0][b_tile_index]);
FETCH_FLOAT4(frag_b[0][4]) = FETCH_FLOAT4(Bs[load_stage_idx ^ 1][0][b_tile_index + 64]);
// compute C THREAD_SIZE_X x THREAD_SIZE_Y
#pragma unroll
for (int thread_y = 0; thread_y < THREAD_SIZE_Y; ++thread_y) {
#pragma unroll
for (int thread_x = 0; thread_x < THREAD_SIZE_X; ++thread_x) {
accum[thread_y][thread_x] += frag_a[1][thread_y] * frag_b[1][thread_x];
}
}
} while (tile_idx < K);
const int c_block_row = a_tile_index;
const int c_block_col = b_tile_index;
//store C00 block
for (int i = 0; i < 4; i++) {
FETCH_FLOAT4(C[OFFSET(
BLOCK_SIZE_M * by + c_block_row + i,
BLOCK_SIZE_N * bx + c_block_col,
N)]) = FETCH_FLOAT4(accum[i][0]);
}
//store C01 block
for (int i = 0; i < 4; i++) {
FETCH_FLOAT4(C[OFFSET(
BLOCK_SIZE_M * by + c_block_row + i,
BLOCK_SIZE_N * bx + c_block_col + 64,
N)]) = FETCH_FLOAT4(accum[i][4]);
}
//store C10 block
for (int i = 0; i < 4; i++) {
FETCH_FLOAT4(C[OFFSET(
BLOCK_SIZE_M * by + c_block_row + 64 + i,
BLOCK_SIZE_N * bx + c_block_col,
N)]) = FETCH_FLOAT4(accum[i + 4][0]);
}
//store C11 block
for (int i = 0; i < 4; i++) {
FETCH_FLOAT4(C[OFFSET(
BLOCK_SIZE_M * by + c_block_row + 64 + i,
BLOCK_SIZE_N * bx + c_block_col + 64,
N)]) = FETCH_FLOAT4(accum[i + 4][4]);
}
}
int main(int argc, char** argv) {
size_t M = 512;
size_t K = 512;
size_t N = 512;
assert(M % 8 == 0);
assert(N % 8 == 0);
assert(K % 8 == 0);
size_t bytes_A = sizeof(float) * M * K;
size_t bytes_B = sizeof(float) * K * N;
size_t bytes_C = sizeof(float) * M * N;
float* h_A = (float*)malloc(bytes_A);
float* h_B = (float*)malloc(bytes_B);
float* h_C = (float*)malloc(bytes_C);
float* h_C1 = (float*)malloc(bytes_C);
float* d_A;
float* d_B;
float* d_C;
checkCudaErrors(cudaMalloc(&d_A, bytes_A));
checkCudaErrors(cudaMalloc(&d_B, bytes_B));
checkCudaErrors(cudaMalloc(&d_C, bytes_C));
double msecPerMatrixMul[2] = { 0, 0 };
double gigaFlops[2] = { 0, 0 };
double flopsPerMatrixMul = 2.0 * M * N * K;
// fixed tiling parameters; the kernel's index arithmetic assumes these values, so do not change them
const int BLOCK_SIZE_M = 128;
const int BLOCK_SIZE_K = 8;
const int BLOCK_SIZE_N = 128;
const int THREAD_SIZE_X = 8;
const int THREAD_SIZE_Y = 8;
const bool ENABLE_DOUBLE_BUFFER = false;
// generate data for A
for (int i = 0; i < M * K; i++) {
h_A[i] = i / 13;
}
// generate data for B
for (int i = 0; i < K * N; i++) {
h_B[i] = i % 13;
}
checkCudaErrors(cudaMemcpy(d_A, h_A, bytes_A, cudaMemcpyHostToDevice));
checkCudaErrors(cudaMemcpy(d_B, h_B, bytes_B, cudaMemcpyHostToDevice));
cudaEvent_t start, stop;
checkCudaErrors(cudaEventCreate(&start));
checkCudaErrors(cudaEventCreate(&stop));
float msecTotal = 0;
int nIter = 1000;
checkCudaErrors(cudaMemcpy(d_C, h_C, bytes_C, cudaMemcpyHostToDevice));
checkCudaErrors(cudaEventRecord(start));
for (int run = 0; run < nIter; run++) {
dim3 dimBlock(BLOCK_SIZE_N / THREAD_SIZE_X, BLOCK_SIZE_M / THREAD_SIZE_Y);
dim3 dimGrid(N / BLOCK_SIZE_N, M / BLOCK_SIZE_M);
Sgemm<BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N, THREAD_SIZE_Y, THREAD_SIZE_X, ENABLE_DOUBLE_BUFFER>
<<<dimGrid, dimBlock>>>(d_A, d_B, d_C, M, N, K);
}
checkCudaErrors(cudaEventRecord(stop));
checkCudaErrors(cudaEventSynchronize(stop));
checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));
checkCudaErrors(cudaMemcpy(h_C, d_C, bytes_C, cudaMemcpyDeviceToHost));
msecPerMatrixMul[0] = msecTotal / nIter;
gigaFlops[0] = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul[0] / 1000.0f);
printf("My gemm Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops,\n",
gigaFlops[0],
msecPerMatrixMul[0],
flopsPerMatrixMul);
// cublas
cublasHandle_t blas_handle;
cublasCreate(&blas_handle);
float alpha = 1.0;
float beta = 0;
checkCudaErrors(cudaMemcpy(d_C, h_C, bytes_C, cudaMemcpyHostToDevice));
checkCudaErrors(cudaEventRecord(start));
for (int run = 0; run < nIter; run++) {
// cuBLAS is column-major; with OP_T applied to the row-major inputs, the product lands in d_C in column-major order, which the verification below accounts for by indexing h_C1[col * M + row].
cublasSgemm(blas_handle, CUBLAS_OP_T, CUBLAS_OP_T,
M, N, K, &alpha,
d_A, K, d_B, N, &beta, d_C, N
);
}
checkCudaErrors(cudaEventRecord(stop));
checkCudaErrors(cudaEventSynchronize(stop));
checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));
checkCudaErrors(cudaMemcpy(h_C1, d_C, bytes_C, cudaMemcpyDeviceToHost));
msecPerMatrixMul[1] = msecTotal / nIter;
gigaFlops[1] = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul[1] / 1000.0f);
printf("CuBlas Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops,\n",
gigaFlops[1],
msecPerMatrixMul[1],
flopsPerMatrixMul);
cublasDestroy(blas_handle);
double eps = 1.e-6; // relative error tolerance
bool correct = true;
for (int i = 0; i < M * N; i++) {
int row = i / N;
int col = i % N;
double abs_err = fabs(h_C[i] - h_C1[col * M + row]);
double dot_length = M;
double abs_val = fabs(h_C[i]);
double rel_err = abs_err / abs_val / dot_length;
if (rel_err > eps) {
printf("Error! Matrix[%d][%d]=%.8f, ref=%.8f error term is > %E\n",
row, col, h_C[i], h_C1[col * M + row], eps);
correct = false;
break;
}
}
printf("%s\n", correct ? "Result= PASS" : "Result= FAIL");
printf("ratio= %f\n", gigaFlops[0] / gigaFlops[1]);
// Free Memory
cudaFree(d_A);
cudaFree(d_B);
cudaFree(d_C);
free(h_A);
free(h_B);
free(h_C);
free(h_C1);
}
#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#define BLOCK_X 16
#define BLOCK_Y 16
#define TILE_X 128
#define TILE_X_4 32
#define TILE_Y 128
#define TILE_Y_4 32
#define TILE_K 16
#define WPTN 8
#define WPTM 8
#define WPTN_4 2
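// gemm_kernel_NN: each 256-thread block computes one TILE_Y x TILE_X (128x128)
// tile of row-major C, stepping over K in chunks of TILE_K (16) with
// double-buffered shared memory. Each thread owns an 8x8 (WPTM x WPTN) register
// tile, handled as pairs of float4 (WPTN_4). The *_4 macros are widths measured
// in float4 units.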
__global__ void gemm_kernel_NN(
const float* __restrict__ A,
const float* __restrict__ B,
float4* __restrict__ C,
float alpha, float beta,
int M, int N, int K)
{
__shared__ float4 smem_a[2][TILE_K * TILE_Y_4];
__shared__ float4 smem_b[2][TILE_K * TILE_X_4];
int tx = threadIdx.x % 16;   // thread column within the 16x16 compute layout
int ty = threadIdx.x / 16;   // thread row within the 16x16 compute layout
int tx4 = threadIdx.x % 4;   // A loads: 4 threads per row, one float4 each
int ty4 = threadIdx.x / 4;
int tx32 = threadIdx.x % 32; // B loads: 32 threads per row, one float4 each
int ty32 = threadIdx.x / 32;
const float* pA = (A + K * TILE_Y * blockIdx.y + ty4 * K + tx4 * 4);  // this thread's first A element
const float* pB = (B + TILE_X * blockIdx.x + ty32 * N + tx32 * 4);    // this thread's first B element
float4* pC = C + TILE_Y * blockIdx.y * N / 4 + TILE_X_4 * blockIdx.x; // block's C tile, in float4 units
int sts_a_offset = tx4 * 4 * TILE_Y + ty4;    // A is stored transposed in shared memory, so each k-slice is contiguous
int sts_b_offset = ty32 * TILE_X_4 + tx32;
float4 f4_zero = make_float4(0.f, 0.f, 0.f, 0.f);
bool valid_ld_a_0 = ((blockIdx.y * TILE_Y + ty4) < M) && ((tx4 * 4) < K);
bool valid_ld_a_1 = ((blockIdx.y * TILE_Y + ty4 + 64) < M) && ((tx4 * 4) < K);
bool valid_ld_b_0 = ((blockIdx.x * TILE_X + tx32 * 4) < N) && (ty32 < K);
bool valid_ld_b_1 = ((blockIdx.x * TILE_X + tx32 * 4) < N) && ((ty32 + 8) < K);
float4 ldg_a_reg[2];
float4 ldg_b_reg[2];
ldg_a_reg[0] = valid_ld_a_0 ? *(const float4*)pA : f4_zero;
ldg_a_reg[1] = valid_ld_a_1 ? *(const float4*)(pA + 64 * K) : f4_zero;
ldg_b_reg[0] = valid_ld_b_0 ? *(const float4*)(pB + 0 * N) : f4_zero;
ldg_b_reg[1] = valid_ld_b_1 ? *(const float4*)(pB + 8 * N) : f4_zero;
float4 c[WPTM][WPTN_4] = { { f4_zero } };
*((float*)&smem_a[0][0] + sts_a_offset + 0 * TILE_Y + 0) = ldg_a_reg[0].x;
*((float*)&smem_a[0][0] + sts_a_offset + 1 * TILE_Y + 0) = ldg_a_reg[0].y;
*((float*)&smem_a[0][0] + sts_a_offset + 2 * TILE_Y + 0) = ldg_a_reg[0].z;
*((float*)&smem_a[0][0] + sts_a_offset + 3 * TILE_Y + 0) = ldg_a_reg[0].w;
*((float*)&smem_a[0][0] + sts_a_offset + 0 * TILE_Y + 64) = ldg_a_reg[1].x;
*((float*)&smem_a[0][0] + sts_a_offset + 1 * TILE_Y + 64) = ldg_a_reg[1].y;
*((float*)&smem_a[0][0] + sts_a_offset + 2 * TILE_Y + 64) = ldg_a_reg[1].z;
*((float*)&smem_a[0][0] + sts_a_offset + 3 * TILE_Y + 64) = ldg_a_reg[1].w;
smem_b[0][sts_b_offset + 0] = ldg_b_reg[0];
smem_b[0][sts_b_offset + 8 * TILE_X_4] = ldg_b_reg[1];
__syncthreads();
int i = 0;                 // K offset of the tile currently being consumed
int write_stage_idx = 1;   // double-buffer stage that will be written next
float4 reg_a[2][2];
float4 reg_b[2][2];
reg_a[0][0] = smem_a[0][0 + ty];   // A fragment: rows [ty*4, ty*4+4) of the tile
reg_a[0][1] = smem_a[0][16 + ty];  // A fragment: rows [64+ty*4, 64+ty*4+4)
reg_b[0][0] = smem_b[0][0 + tx];   // B fragment: cols [tx*4, tx*4+4)
reg_b[0][1] = smem_b[0][16 + tx];  // B fragment: cols [64+tx*4, 64+tx*4+4)
do
{
i += 16;
valid_ld_a_0 = (valid_ld_a_0 && ((tx4 * 4 + i) < K));
valid_ld_a_1 = (valid_ld_a_1 && ((tx4 * 4 + i) < K));
valid_ld_b_0 = (valid_ld_b_0 && ((ty32 + i) < K));
valid_ld_b_1 = (valid_ld_b_1 && ((ty32 + 8 + i) < K));
ldg_a_reg[0] = (valid_ld_a_0) ? *(const float4*)(pA + i + 0) : f4_zero;
ldg_a_reg[1] = (valid_ld_a_1) ? *(const float4*)(pA + i + 64 * K) : f4_zero;
ldg_b_reg[0] = (valid_ld_b_0) ? *(const float4*)(pB + (i + 0) * N) : f4_zero;
ldg_b_reg[1] = (valid_ld_b_1) ? *(const float4*)(pB + (i + 8) * N) : f4_zero;
int load_stage_idx = write_stage_idx ^ 1;
#pragma unroll
for (int j = 0; j < TILE_K - 1; j++)
{
reg_a[(j + 1) % 2][0] = smem_a[load_stage_idx][(j + 1) * TILE_Y_4 + 0 + ty];
reg_a[(j + 1) % 2][1] = smem_a[load_stage_idx][(j + 1) * TILE_Y_4 + 16 + ty];
reg_b[(j + 1) % 2][0] = smem_b[load_stage_idx][(j + 1) * TILE_X_4 + 0 + tx];
reg_b[(j + 1) % 2][1] = smem_b[load_stage_idx][(j + 1) * TILE_X_4 + 16 + tx];
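// 8x8 outer product of the current A/B fragments, fully unrolled into scalar-by-float4 FMAs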
c[0][0].x += reg_a[j % 2][0].x * reg_b[j % 2][0].x;
c[0][0].y += reg_a[j % 2][0].x * reg_b[j % 2][0].y;
c[0][0].z += reg_a[j % 2][0].x * reg_b[j % 2][0].z;
c[0][0].w += reg_a[j % 2][0].x * reg_b[j % 2][0].w;
c[0][1].x += reg_a[j % 2][0].x * reg_b[j % 2][1].x;
c[0][1].y += reg_a[j % 2][0].x * reg_b[j % 2][1].y;
c[0][1].z += reg_a[j % 2][0].x * reg_b[j % 2][1].z;
c[0][1].w += reg_a[j % 2][0].x * reg_b[j % 2][1].w;
c[1][0].x += reg_a[j % 2][0].y * reg_b[j % 2][0].x;
c[1][0].y += reg_a[j % 2][0].y * reg_b[j % 2][0].y;
c[1][0].z += reg_a[j % 2][0].y * reg_b[j % 2][0].z;
c[1][0].w += reg_a[j % 2][0].y * reg_b[j % 2][0].w;
c[1][1].x += reg_a[j % 2][0].y * reg_b[j % 2][1].x;
c[1][1].y += reg_a[j % 2][0].y * reg_b[j % 2][1].y;
c[1][1].z += reg_a[j % 2][0].y * reg_b[j % 2][1].z;
c[1][1].w += reg_a[j % 2][0].y * reg_b[j % 2][1].w;
c[2][0].x += reg_a[j % 2][0].z * reg_b[j % 2][0].x;
c[2][0].y += reg_a[j % 2][0].z * reg_b[j % 2][0].y;
c[2][0].z += reg_a[j % 2][0].z * reg_b[j % 2][0].z;
c[2][0].w += reg_a[j % 2][0].z * reg_b[j % 2][0].w;
c[2][1].x += reg_a[j % 2][0].z * reg_b[j % 2][1].x;
c[2][1].y += reg_a[j % 2][0].z * reg_b[j % 2][1].y;
c[2][1].z += reg_a[j % 2][0].z * reg_b[j % 2][1].z;
c[2][1].w += reg_a[j % 2][0].z * reg_b[j % 2][1].w;
c[3][0].x += reg_a[j % 2][0].w * reg_b[j % 2][0].x;
c[3][0].y += reg_a[j % 2][0].w * reg_b[j % 2][0].y;
c[3][0].z += reg_a[j % 2][0].w * reg_b[j % 2][0].z;
c[3][0].w += reg_a[j % 2][0].w * reg_b[j % 2][0].w;
c[3][1].x += reg_a[j % 2][0].w * reg_b[j % 2][1].x;
c[3][1].y += reg_a[j % 2][0].w * reg_b[j % 2][1].y;
c[3][1].z += reg_a[j % 2][0].w * reg_b[j % 2][1].z;
c[3][1].w += reg_a[j % 2][0].w * reg_b[j % 2][1].w;
c[4][0].x += reg_a[j % 2][1].x * reg_b[j % 2][0].x;
c[4][0].y += reg_a[j % 2][1].x * reg_b[j % 2][0].y;
c[4][0].z += reg_a[j % 2][1].x * reg_b[j % 2][0].z;
c[4][0].w += reg_a[j % 2][1].x * reg_b[j % 2][0].w;
c[4][1].x += reg_a[j % 2][1].x * reg_b[j % 2][1].x;
c[4][1].y += reg_a[j % 2][1].x * reg_b[j % 2][1].y;
c[4][1].z += reg_a[j % 2][1].x * reg_b[j % 2][1].z;
c[4][1].w += reg_a[j % 2][1].x * reg_b[j % 2][1].w;
c[5][0].x += reg_a[j % 2][1].y * reg_b[j % 2][0].x;
c[5][0].y += reg_a[j % 2][1].y * reg_b[j % 2][0].y;
c[5][0].z += reg_a[j % 2][1].y * reg_b[j % 2][0].z;
c[5][0].w += reg_a[j % 2][1].y * reg_b[j % 2][0].w;
c[5][1].x += reg_a[j % 2][1].y * reg_b[j % 2][1].x;
c[5][1].y += reg_a[j % 2][1].y * reg_b[j % 2][1].y;
c[5][1].z += reg_a[j % 2][1].y * reg_b[j % 2][1].z;
c[5][1].w += reg_a[j % 2][1].y * reg_b[j % 2][1].w;
c[6][0].x += reg_a[j % 2][1].z * reg_b[j % 2][0].x;
c[6][0].y += reg_a[j % 2][1].z * reg_b[j % 2][0].y;
c[6][0].z += reg_a[j % 2][1].z * reg_b[j % 2][0].z;
c[6][0].w += reg_a[j % 2][1].z * reg_b[j % 2][0].w;
c[6][1].x += reg_a[j % 2][1].z * reg_b[j % 2][1].x;
c[6][1].y += reg_a[j % 2][1].z * reg_b[j % 2][1].y;
c[6][1].z += reg_a[j % 2][1].z * reg_b[j % 2][1].z;
c[6][1].w += reg_a[j % 2][1].z * reg_b[j % 2][1].w;
c[7][0].x += reg_a[j % 2][1].w * reg_b[j % 2][0].x;
c[7][0].y += reg_a[j % 2][1].w * reg_b[j % 2][0].y;
c[7][0].z += reg_a[j % 2][1].w * reg_b[j % 2][0].z;
c[7][0].w += reg_a[j % 2][1].w * reg_b[j % 2][0].w;
c[7][1].x += reg_a[j % 2][1].w * reg_b[j % 2][1].x;
c[7][1].y += reg_a[j % 2][1].w * reg_b[j % 2][1].y;
c[7][1].z += reg_a[j % 2][1].w * reg_b[j % 2][1].z;
c[7][1].w += reg_a[j % 2][1].w * reg_b[j % 2][1].w;
}
if(i < K) {
*((float*)&smem_a[write_stage_idx][0] + sts_a_offset + 0 * TILE_Y + 0) = ldg_a_reg[0].x;
*((float*)&smem_a[write_stage_idx][0] + sts_a_offset + 1 * TILE_Y + 0) = ldg_a_reg[0].y;
*((float*)&smem_a[write_stage_idx][0] + sts_a_offset + 2 * TILE_Y + 0) = ldg_a_reg[0].z;
*((float*)&smem_a[write_stage_idx][0] + sts_a_offset + 3 * TILE_Y + 0) = ldg_a_reg[0].w;
*((float*)&smem_a[write_stage_idx][0] + sts_a_offset + 0 * TILE_Y + 64) = ldg_a_reg[1].x;
*((float*)&smem_a[write_stage_idx][0] + sts_a_offset + 1 * TILE_Y + 64) = ldg_a_reg[1].y;
*((float*)&smem_a[write_stage_idx][0] + sts_a_offset + 2 * TILE_Y + 64) = ldg_a_reg[1].z;
*((float*)&smem_a[write_stage_idx][0] + sts_a_offset + 3 * TILE_Y + 64) = ldg_a_reg[1].w;
smem_b[write_stage_idx][sts_b_offset + 0] = ldg_b_reg[0];
smem_b[write_stage_idx][sts_b_offset + 8 * TILE_X_4] = ldg_b_reg[1];
__syncthreads();
write_stage_idx ^= 1;
}
reg_a[0][0] = smem_a[load_stage_idx ^ 1][0 + ty];
reg_a[0][1] = smem_a[load_stage_idx ^ 1][16 + ty];
reg_b[0][0] = smem_b[load_stage_idx ^ 1][0 + tx];
reg_b[0][1] = smem_b[load_stage_idx ^ 1][16 + tx];
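// final 8x8 outer product for the last k-step of this tile (fragments live in slot 1)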
c[0][0].x += reg_a[1][0].x * reg_b[1][0].x;
c[0][0].y += reg_a[1][0].x * reg_b[1][0].y;
c[0][0].z += reg_a[1][0].x * reg_b[1][0].z;
c[0][0].w += reg_a[1][0].x * reg_b[1][0].w;
c[0][1].x += reg_a[1][0].x * reg_b[1][1].x;
c[0][1].y += reg_a[1][0].x * reg_b[1][1].y;
c[0][1].z += reg_a[1][0].x * reg_b[1][1].z;
c[0][1].w += reg_a[1][0].x * reg_b[1][1].w;
c[1][0].x += reg_a[1][0].y * reg_b[1][0].x;
c[1][0].y += reg_a[1][0].y * reg_b[1][0].y;
c[1][0].z += reg_a[1][0].y * reg_b[1][0].z;
c[1][0].w += reg_a[1][0].y * reg_b[1][0].w;
c[1][1].x += reg_a[1][0].y * reg_b[1][1].x;
c[1][1].y += reg_a[1][0].y * reg_b[1][1].y;
c[1][1].z += reg_a[1][0].y * reg_b[1][1].z;
c[1][1].w += reg_a[1][0].y * reg_b[1][1].w;
c[2][0].x += reg_a[1][0].z * reg_b[1][0].x;
c[2][0].y += reg_a[1][0].z * reg_b[1][0].y;
c[2][0].z += reg_a[1][0].z * reg_b[1][0].z;
c[2][0].w += reg_a[1][0].z * reg_b[1][0].w;
c[2][1].x += reg_a[1][0].z * reg_b[1][1].x;
c[2][1].y += reg_a[1][0].z * reg_b[1][1].y;
c[2][1].z += reg_a[1][0].z * reg_b[1][1].z;
c[2][1].w += reg_a[1][0].z * reg_b[1][1].w;
c[3][0].x += reg_a[1][0].w * reg_b[1][0].x;
c[3][0].y += reg_a[1][0].w * reg_b[1][0].y;
c[3][0].z += reg_a[1][0].w * reg_b[1][0].z;
c[3][0].w += reg_a[1][0].w * reg_b[1][0].w;
c[3][1].x += reg_a[1][0].w * reg_b[1][1].x;
c[3][1].y += reg_a[1][0].w * reg_b[1][1].y;
c[3][1].z += reg_a[1][0].w * reg_b[1][1].z;
c[3][1].w += reg_a[1][0].w * reg_b[1][1].w;
c[4][0].x += reg_a[1][1].x * reg_b[1][0].x;
c[4][0].y += reg_a[1][1].x * reg_b[1][0].y;
c[4][0].z += reg_a[1][1].x * reg_b[1][0].z;
c[4][0].w += reg_a[1][1].x * reg_b[1][0].w;
c[4][1].x += reg_a[1][1].x * reg_b[1][1].x;
c[4][1].y += reg_a[1][1].x * reg_b[1][1].y;
c[4][1].z += reg_a[1][1].x * reg_b[1][1].z;
c[4][1].w += reg_a[1][1].x * reg_b[1][1].w;
c[5][0].x += reg_a[1][1].y * reg_b[1][0].x;
c[5][0].y += reg_a[1][1].y * reg_b[1][0].y;
c[5][0].z += reg_a[1][1].y * reg_b[1][0].z;
c[5][0].w += reg_a[1][1].y * reg_b[1][0].w;
c[5][1].x += reg_a[1][1].y * reg_b[1][1].x;
c[5][1].y += reg_a[1][1].y * reg_b[1][1].y;
c[5][1].z += reg_a[1][1].y * reg_b[1][1].z;
c[5][1].w += reg_a[1][1].y * reg_b[1][1].w;
c[6][0].x += reg_a[1][1].z * reg_b[1][0].x;
c[6][0].y += reg_a[1][1].z * reg_b[1][0].y;
c[6][0].z += reg_a[1][1].z * reg_b[1][0].z;
c[6][0].w += reg_a[1][1].z * reg_b[1][0].w;
c[6][1].x += reg_a[1][1].z * reg_b[1][1].x;
c[6][1].y += reg_a[1][1].z * reg_b[1][1].y;
c[6][1].z += reg_a[1][1].z * reg_b[1][1].z;
c[6][1].w += reg_a[1][1].z * reg_b[1][1].w;
c[7][0].x += reg_a[1][1].w * reg_b[1][0].x;
c[7][0].y += reg_a[1][1].w * reg_b[1][0].y;
c[7][0].z += reg_a[1][1].w * reg_b[1][0].z;
c[7][0].w += reg_a[1][1].w * reg_b[1][0].w;
c[7][1].x += reg_a[1][1].w * reg_b[1][1].x;
c[7][1].y += reg_a[1][1].w * reg_b[1][1].y;
c[7][1].z += reg_a[1][1].w * reg_b[1][1].z;
c[7][1].w += reg_a[1][1].w * reg_b[1][1].w;
} while (i < K);
#pragma unroll
for (int wm = 0; wm < WPTM; wm++){
#pragma unroll
for (int wn = 0; wn < WPTN_4; wn++){
c[wm][wn].x *= alpha;
c[wm][wn].y *= alpha;
c[wm][wn].z *= alpha;
c[wm][wn].w *= alpha;
}
}
#pragma unroll
for (int wm = 0; wm < 4; wm++){
#pragma unroll
for (int wn = 0; wn < WPTN_4; wn++){
if (((blockIdx.y * TILE_Y + ty * 4 + wm) < M)
&& ((blockIdx.x * TILE_X + wn * 64 + tx * 4) < N)) {
if (beta != 0) {
float4 vec4c = *(pC + ((ty * 4 + wm) * N / 4 + wn * 16 + tx));
vec4c.x = vec4c.x * beta + c[wm][wn].x;
vec4c.y = vec4c.y * beta + c[wm][wn].y;
vec4c.z = vec4c.z * beta + c[wm][wn].z;
vec4c.w = vec4c.w * beta + c[wm][wn].w;
*(pC + (ty * 4 + wm) * N / 4 + wn * 16 + tx) = vec4c;
} else {
*(pC + (ty * 4 + wm) * N / 4 + wn * 16 + tx) = c[wm][wn];
}
}
}
}
#pragma unroll
for (int wm = 0; wm < 4; wm++){
#pragma unroll
for (int wn = 0; wn < WPTN_4; wn++){
if (((blockIdx.y * TILE_Y + 64 + ty * 4 + wm) < M)
&& ((blockIdx.x * TILE_X + wn * 64 + tx * 4) < N)) {
if (beta != 0) {
float4 vec4c = *(pC + ((64 + ty * 4 + wm) * N / 4 + wn * 16 + tx));
vec4c.x = vec4c.x * beta + c[wm + 4][wn].x;
vec4c.y = vec4c.y * beta + c[wm + 4][wn].y;
vec4c.z = vec4c.z * beta + c[wm + 4][wn].z;
vec4c.w = vec4c.w * beta + c[wm + 4][wn].w;
*(pC + (64 + ty * 4 + wm) * N / 4 + wn * 16 + tx) = vec4c;
} else {
*(pC + (64 + ty * 4 + wm) * N / 4 + wn * 16 + tx) = c[wm + 4][wn];
}
}
}
}
}
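// A usage sketch (not in the original file): gemm_kernel_NN takes row-major
// A (MxK) and B (KxN) and writes row-major C (MxN) through a float4 pointer, so
// N (and K) are assumed to be multiples of 4 for the vectorized accesses. Each
// 256-thread block computes one 128x128 tile of C; edges are bounds-checked, so
// the grid rounds up. The wrapper name below is hypothetical.
static void launch_gemm_nn(const float* d_A, const float* d_B, float* d_C,
                           float alpha, float beta, int M, int N, int K) {
    dim3 block(BLOCK_X * BLOCK_Y);                                   // 256 threads per block
    dim3 grid((N + TILE_X - 1) / TILE_X, (M + TILE_Y - 1) / TILE_Y); // one block per C tile
    gemm_kernel_NN<<<grid, block>>>(d_A, d_B,
                                    reinterpret_cast<float4*>(d_C),
                                    alpha, beta, M, N, K);
}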