sgemm

#define A(i,j) A[(i) + (j)*lda]
#define B(i,j) B[(i) + (j)*ldb]
#define C(i,j) C[(i) + (j)*ldc]
#define sa(i,j) sa[((i)<<5) + (j)]
#define sb(i,j) sb[((i)<<5) + (j)]
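// sa/sb are 32x32 tiles stored with a stride of 32 floats ((i)<<5 == (i)*32);
// sb is filled with the B tile transposed (see the loads below), so
// sb(col, k) returns B(k, col) during the inner product.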
#define MS 32
#define NS 32
#define KS 32
// cache-blocking version, without register-level data reuse
// the flat 1-D thread index (tx) is split into row/col below, which saves one live register (ty)
__global__  __launch_bounds__(1024)
void mysgemm_v3(int M, int N, int K, float alpha, float* A, float* B, float beta, float* C){
    int lda = M, ldb = K, ldc = M;
    int tx = threadIdx.x;
    int bx = blockIdx.x, by = blockIdx.y;
    int row = tx&31, col = tx>>5;
    A = &A((bx<<5),0);
    B = &B(0,(by<<5));
    C = &C((bx<<5),(by<<5));
    __shared__ float sa[MS*KS];
    __shared__ float sb[KS*NS];
    float tmp=0.;
    for (int k_count = 0; k_count<K; k_count+=KS){
        sa(row,col)=A(row,col);
        sb(col,row)=B(row,col);
        A+=(lda<<5);B+=32;
        __syncthreads();
        for (int inner_k_count=0;inner_k_count<KS;inner_k_count++){
            tmp += sa(row,inner_k_count) * sb(col,inner_k_count);
        }
        __syncthreads();
    }
    C(row,col) = alpha * tmp + beta*C(row,col);
}
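
A minimal host-side launch sketch for mysgemm_v3 (my addition, not from the original source): one 1024-thread block per 32x32 tile of C, assuming column-major A, B, C and M, N, K all multiples of 32.

// hypothetical launcher; names and layout assumptions are mine
void launch_mysgemm_v3(int M, int N, int K, float alpha,
                       float* dA, float* dB, float beta, float* dC) {
    dim3 block(1024);           // 32 x 32 threads, flat index
    dim3 grid(M / 32, N / 32);  // blockIdx.x tiles rows (M), blockIdx.y tiles columns (N)
    mysgemm_v3<<<grid, block>>>(M, N, K, alpha, dA, dB, beta, dC);
}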
// optimized SGEMM: shared-memory tiling, float4 vector loads, and double buffering

#include <stdio.h>
#include <stdlib.h>
#include "assert.h" 

// CUDA runtime
#include <cuda_runtime.h>
#include <cublas_v2.h>

// compute the offset of element (row, col) in a row-major matrix; ld is the leading dimension (row width)
#define OFFSET(row, col, ld) ((row) * (ld) + (col))

// transfer float4
#define FETCH_FLOAT4(pointer) (reinterpret_cast<float4*>(&(pointer))[0])
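
// note: FETCH_FLOAT4 issues a 128-bit vectorized access, so the address must be
// 16-byte aligned; that holds below because all tile widths and thread offsets
// are multiples of 4 floats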

#define checkCudaErrors(func)                                                   \
{                                                                               \
    cudaError_t e = (func);                                                     \
    if (e != cudaSuccess)                                                       \
        printf("%s %d CUDA: %s\n", __FILE__, __LINE__, cudaGetErrorString(e));  \
}

// all matrices are row-major: lda = K, ldb = N, ldc = N
template <
    const int BLOCK_SIZE_M,  // height of the C block each thread block computes
    const int BLOCK_SIZE_K,  // width of the A block each thread block stages in shared memory
    const int BLOCK_SIZE_N,  // width of the C block each thread block computes
    const int THREAD_SIZE_Y, // height of the C sub-block each thread computes
    const int THREAD_SIZE_X, // width of the C sub-block each thread computes
    const bool ENABLE_DOUBLE_BUFFER // whether to enable double buffering (unused: this kernel always double-buffers)
>
__global__ void Sgemm(
    float* __restrict__ A,
    float* __restrict__ B,
    float* __restrict__ C,
    const int M,
    const int N,
    const int K) {
    // Block index
    int bx = blockIdx.x;
    int by = blockIdx.y;

    // Thread index
    int tx = threadIdx.x;
    int ty = threadIdx.y;

    // number of threads per block along x and y
    const int THREAD_X_PER_BLOCK = BLOCK_SIZE_N / THREAD_SIZE_X;
    const int THREAD_Y_PER_BLOCK = BLOCK_SIZE_M / THREAD_SIZE_Y;
    const int THREAD_NUM_PER_BLOCK = THREAD_X_PER_BLOCK * THREAD_Y_PER_BLOCK;

    // flat thread id within the current block
    const int tid = ty * THREAD_X_PER_BLOCK + tx;

    // shared memory
    __shared__ float As[2][BLOCK_SIZE_K][BLOCK_SIZE_M];
    __shared__ float Bs[2][BLOCK_SIZE_K][BLOCK_SIZE_N];
    // registers for C
    float accum[THREAD_SIZE_Y][THREAD_SIZE_X];
#pragma unroll
    for (int i = 0; i < THREAD_SIZE_Y; i++) {
#pragma unroll
        for (int j = 0; j < THREAD_SIZE_X; j++) {
            accum[i][j] = 0.0;
        }
    }
    // registers for A and B
    float frag_a[2][THREAD_SIZE_Y];
    float frag_b[2][THREAD_SIZE_X];
    // registers for staging global-memory loads
    const int ldg_num_a = BLOCK_SIZE_M * BLOCK_SIZE_K / (THREAD_NUM_PER_BLOCK * 4);
    const int ldg_num_b = BLOCK_SIZE_K * BLOCK_SIZE_N / (THREAD_NUM_PER_BLOCK * 4);
    float ldg_a_reg[4 * ldg_num_a];
    float ldg_b_reg[4 * ldg_num_b];

    // number of threads needed to load one tile row (4 floats per thread)
    const int A_TILE_THREAD_PER_ROW = BLOCK_SIZE_K / 4;
    const int B_TILE_THREAD_PER_ROW = BLOCK_SIZE_N / 4;

    // starting row and column this thread loads from
    const int A_TILE_ROW_START = tid / A_TILE_THREAD_PER_ROW;
    const int B_TILE_ROW_START = tid / B_TILE_THREAD_PER_ROW;

    const int A_TILE_COL = tid % A_TILE_THREAD_PER_ROW * 4;
    const int B_TILE_COL = tid % B_TILE_THREAD_PER_ROW * 4;

    // row stride that thread uses to load multiple rows of a tile
    const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / A_TILE_THREAD_PER_ROW;
    const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / B_TILE_THREAD_PER_ROW;
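
    // worked example for the configuration used in main() below
    // (BLOCK_SIZE_M = BLOCK_SIZE_N = 128, BLOCK_SIZE_K = 8, THREAD_SIZE_X = THREAD_SIZE_Y = 8):
    // 256 threads per block, ldg_num_a = ldg_num_b = 1, A_TILE_THREAD_PER_ROW = 2,
    // A_TILE_ROW_STRIDE = 128, B_TILE_THREAD_PER_ROW = 32, B_TILE_ROW_STRIDE = 8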

    A = &A[(BLOCK_SIZE_M * by) * K];
    B = &B[BLOCK_SIZE_N * bx];

    // warp/lane indices used to locate this thread's C sub-tiles
    const int warp_id = tid / 32;
    const int lane_id = tid % 32;
    const int a_tile_index = warp_id / 2 * 16 + lane_id / 8 * 4; //warp_id * 8 + (lane_id / 16)*4; // (warp_id/4)*32 + ((lane_id%16)/2)*4;
    const int b_tile_index = warp_id % 2 * 32 + lane_id % 8 * 4; //(lane_id % 16) * 4; // (warp_id%4)*16 + (lane_id/16)*8 + (lane_id%2)*4;
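    // the 8 warps form a 4 x 2 grid, each owning a 16 x 32 patch of the upper-left
    // 64 x 64 quadrant of the 128 x 128 C block; each thread's 8 x 8 result is split
    // into four 4 x 4 sub-tiles at offsets (0,0), (0,64), (64,0) and (64,64),
    // which is why the fragments below are loaded as two float4 halves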

    //transfer first tile from global mem to shared mem
    // load A from global memory to shared memory
#pragma unroll
    for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
        int ldg_index = i / A_TILE_ROW_STRIDE * 4;
        FETCH_FLOAT4(ldg_a_reg[ldg_index]) = FETCH_FLOAT4(A[OFFSET(
            A_TILE_ROW_START + i, // row
            A_TILE_COL, // col
            K)]);
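        // store the A tile transposed (As[k][m]) so the inner loop can read
        // THREAD_SIZE_Y consecutive m-values with a single float4 load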
        As[0][A_TILE_COL][A_TILE_ROW_START + i] = ldg_a_reg[ldg_index];
        As[0][A_TILE_COL + 1][A_TILE_ROW_START + i] = ldg_a_reg[ldg_index + 1];
        As[0][A_TILE_COL + 2][A_TILE_ROW_START + i] = ldg_a_reg[ldg_index + 2];
        As[0][A_TILE_COL + 3][A_TILE_ROW_START + i] = ldg_a_reg[ldg_index + 3];
    }
    // load B from global memory to shared memory
#pragma unroll
    for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
        FETCH_FLOAT4(Bs[0][B_TILE_ROW_START + i][B_TILE_COL]) = FETCH_FLOAT4(B[OFFSET(
            B_TILE_ROW_START + i, // row
            B_TILE_COL, // col
            N)]);
    }
    __syncthreads();

    // load A from shared memory to register
    FETCH_FLOAT4(frag_a[0][0]) = FETCH_FLOAT4(As[0][0][a_tile_index]);
    FETCH_FLOAT4(frag_a[0][4]) = FETCH_FLOAT4(As[0][0][a_tile_index + 64]);

    // load B from shared memory to register
    FETCH_FLOAT4(frag_b[0][0]) = FETCH_FLOAT4(Bs[0][0][b_tile_index]);
    FETCH_FLOAT4(frag_b[0][4]) = FETCH_FLOAT4(Bs[0][0][b_tile_index + 64]);

    int write_stage_idx = 1;
    int tile_idx = 0;
    do {
        // next tile index
        tile_idx += BLOCK_SIZE_K;
        // load next tile from global mem
        if (tile_idx < K) {
#pragma unroll
            for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
                int ldg_index = i / A_TILE_ROW_STRIDE * 4;
                FETCH_FLOAT4(ldg_a_reg[ldg_index]) = FETCH_FLOAT4(A[OFFSET(
                    A_TILE_ROW_START + i, // row
                    A_TILE_COL + tile_idx, // col
                    K)]);
            }
#pragma unroll
            for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
                int ldg_index = i / B_TILE_ROW_STRIDE * 4;
                FETCH_FLOAT4(ldg_b_reg[ldg_index]) = FETCH_FLOAT4(B[OFFSET(
                    tile_idx + B_TILE_ROW_START + i, // row
                    B_TILE_COL, // col
                    N)]);
            }
        }

        int load_stage_idx = write_stage_idx ^ 1;

#pragma unroll
        for (int j = 0; j < BLOCK_SIZE_K - 1; ++j) {
            // load next tile from shared mem to register 
            // load A from shared memory to register
            FETCH_FLOAT4(frag_a[(j + 1) % 2][0]) = FETCH_FLOAT4(As[load_stage_idx][(j + 1)][a_tile_index]);
            FETCH_FLOAT4(frag_a[(j + 1) % 2][4]) = FETCH_FLOAT4(As[load_stage_idx][(j + 1)][a_tile_index + 64]);
            // load B from shared memory to register
            FETCH_FLOAT4(frag_b[(j + 1) % 2][0]) = FETCH_FLOAT4(Bs[load_stage_idx][(j + 1)][b_tile_index]);
            FETCH_FLOAT4(frag_b[(j + 1) % 2][4]) = FETCH_FLOAT4(Bs[load_stage_idx][(j + 1)][b_tile_index + 64]);
            // compute C THREAD_SIZE_X x THREAD_SIZE_Y
#pragma unroll
            for (int thread_y = 0; thread_y < THREAD_SIZE_Y; ++thread_y) {
#pragma unroll
                for (int thread_x = 0; thread_x < THREAD_SIZE_X; ++thread_x) {
                    accum[thread_y][thread_x] += frag_a[j % 2][thread_y] * frag_b[j % 2][thread_x];
                }
            }
        }

        if (tile_idx < K) {
            // load A from global memory to shared memory
#pragma unroll
            for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
                int ldg_index = i / A_TILE_ROW_STRIDE * 4;
                As[write_stage_idx][A_TILE_COL][A_TILE_ROW_START + i] = ldg_a_reg[ldg_index];
                As[write_stage_idx][A_TILE_COL + 1][A_TILE_ROW_START + i] = ldg_a_reg[ldg_index + 1];
                As[write_stage_idx][A_TILE_COL + 2][A_TILE_ROW_START + i] = ldg_a_reg[ldg_index + 2];
                As[write_stage_idx][A_TILE_COL + 3][A_TILE_ROW_START + i] = ldg_a_reg[ldg_index + 3];
            }
            // load B from global memory to shared memory
#pragma unroll
            for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
                int ldg_index = i / B_TILE_ROW_STRIDE * 4;
                FETCH_FLOAT4(Bs[write_stage_idx][B_TILE_ROW_START + i][B_TILE_COL]) = FETCH_FLOAT4(ldg_b_reg[ldg_index]);
            }
            // with double buffering, a single __syncthreads() per K-tile suffices
            __syncthreads();
            // switch
            write_stage_idx ^= 1;
        }

        // load first tile from shared mem to register of next iter
        // load A from shared memory to register
        FETCH_FLOAT4(frag_a[0][0]) = FETCH_FLOAT4(As[load_stage_idx ^ 1][0][a_tile_index]);
        FETCH_FLOAT4(frag_a[0][4]) = FETCH_FLOAT4(As[load_stage_idx ^ 1][0][a_tile_index + 64]);
        // load B from shared memory to register
        FETCH_FLOAT4(frag_b[0][0]) = FETCH_FLOAT4(Bs[load_stage_idx ^ 1][0][b_tile_index]);
        FETCH_FLOAT4(frag_b[0][4]) = FETCH_FLOAT4(Bs[load_stage_idx ^ 1][0][b_tile_index + 64]);
        // compute the last fragment (j = BLOCK_SIZE_K - 1); it sits in index 1
        // because BLOCK_SIZE_K is even
#pragma unroll
        for (int thread_y = 0; thread_y < THREAD_SIZE_Y; ++thread_y) {
#pragma unroll
            for (int thread_x = 0; thread_x < THREAD_SIZE_X; ++thread_x) {
                accum[thread_y][thread_x] += frag_a[1][thread_y] * frag_b[1][thread_x];
            }
        }
    } while (tile_idx < K);

    const int c_block_row = a_tile_index;
    const int c_block_col = b_tile_index;

    //store C00 block
    for (int i = 0; i < 4; i++) {
        FETCH_FLOAT4(C[OFFSET(
            BLOCK_SIZE_M * by + c_block_row + i,
            BLOCK_SIZE_N * bx + c_block_col,
            N)]) = FETCH_FLOAT4(accum[i][0]);
    }
    //store C01 block
    for (int i = 0; i < 4; i++) {
        FETCH_FLOAT4(C[OFFSET(
            BLOCK_SIZE_M * by + c_block_row + i,
            BLOCK_SIZE_N * bx + c_block_col + 64,
            N)]) = FETCH_FLOAT4(accum[i][4]);
    }
    //store C10 block
    for (int i = 0; i < 4; i++) {
        FETCH_FLOAT4(C[OFFSET(
            BLOCK_SIZE_M * by + c_block_row + 64 + i,
            BLOCK_SIZE_N * bx + c_block_col,
            N)]) = FETCH_FLOAT4(accum[i + 4][0]);
    }
    //store C11 block
    for (int i = 0; i < 4; i++) {
        FETCH_FLOAT4(C[OFFSET(
            BLOCK_SIZE_M * by + c_block_row + 64 + i,
            BLOCK_SIZE_N * bx + c_block_col + 64,
            N)]) = FETCH_FLOAT4(accum[i + 4][4]);
    }
}

int main(int argc, char** argv) {
    
    size_t M = 512;
    size_t K = 512;
    size_t N = 512;

    assert(M % 8 == 0);
    assert(N % 8 == 0);
    assert(K % 8 == 0);

    size_t bytes_A = sizeof(float) * M * K;
    size_t bytes_B = sizeof(float) * K * N;
    size_t bytes_C = sizeof(float) * M * N;
    float* h_A = (float*)malloc(bytes_A);
    float* h_B = (float*)malloc(bytes_B);
    float* h_C = (float*)malloc(bytes_C);
    float* h_C1 = (float*)malloc(bytes_C);

    float* d_A;
    float* d_B;
    float* d_C;

    checkCudaErrors(cudaMalloc(&d_A, bytes_A));
    checkCudaErrors(cudaMalloc(&d_B, bytes_B));
    checkCudaErrors(cudaMalloc(&d_C, bytes_C));
    double msecPerMatrixMul[2] = { 0, 0 };
    double gigaFlops[2] = { 0, 0 };
    double flopsPerMatrixMul = 2.0 * M * N * K;
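    // 2*M*N*K: one multiply and one add per term of each length-K dot product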

    // kernel tiling parameters; the store offsets in Sgemm assume exactly these values, so don't edit them
    const int BLOCK_SIZE_M = 128;
    const int BLOCK_SIZE_K = 8;
    const int BLOCK_SIZE_N = 128;
    const int THREAD_SIZE_X = 8;
    const int THREAD_SIZE_Y = 8;
    const bool ENABLE_DOUBLE_BUFFER = false;

    // initialize A
    for (int i = 0; i < M * K; i++) {
        h_A[i] = i / 13;
    }

    // initialize B
    for (int i = 0; i < K * N; i++) {
        h_B[i] = i % 13;
    }

    checkCudaErrors(cudaMemcpy(d_A, h_A, bytes_A, cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(d_B, h_B, bytes_B, cudaMemcpyHostToDevice));

    cudaEvent_t start, stop;
    checkCudaErrors(cudaEventCreate(&start));
    checkCudaErrors(cudaEventCreate(&stop));
    float msecTotal = 0;
    int nIter = 1000;

    checkCudaErrors(cudaMemcpy(d_C, h_C, bytes_C, cudaMemcpyHostToDevice));
    checkCudaErrors(cudaEventRecord(start));
    for (int run = 0; run < nIter; run++) {
        dim3 dimBlock(BLOCK_SIZE_N / THREAD_SIZE_X, BLOCK_SIZE_M / THREAD_SIZE_Y);
        dim3 dimGrid(N / BLOCK_SIZE_N, M / BLOCK_SIZE_M);
        Sgemm<BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N, THREAD_SIZE_Y, THREAD_SIZE_X, ENABLE_DOUBLE_BUFFER>
            <<<dimGrid, dimBlock>>>(d_A, d_B, d_C, M, N, K);
    }
    checkCudaErrors(cudaEventRecord(stop));
    checkCudaErrors(cudaEventSynchronize(stop));
    checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));


    checkCudaErrors(cudaMemcpy(h_C, d_C, bytes_C, cudaMemcpyDeviceToHost));

    msecPerMatrixMul[0] = msecTotal / nIter;
    gigaFlops[0] = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul[0] / 1000.0f);
    printf("My gemm Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops,\n",
        gigaFlops[0],
        msecPerMatrixMul[0],
        flopsPerMatrixMul);

    // cublas

    cublasHandle_t blas_handle;
    cublasCreate(&blas_handle);
    float alpha = 1.0;
    float beta = 0;
    checkCudaErrors(cudaMemcpy(d_C, h_C, bytes_C, cudaMemcpyHostToDevice));
    checkCudaErrors(cudaEventRecord(start));
    for (int run = 0; run < nIter; run++) {
        // cuBLAS is column-major; OP_T makes it read the row-major d_A and d_B
        // as their transposes, and the result lands in d_C column-major, which is
        // why the check below indexes h_C1 as h_C1[col * M + row] (valid since M == N)
        cublasSgemm(blas_handle, CUBLAS_OP_T, CUBLAS_OP_T,
            M, N, K, &alpha,
            d_A, K, d_B, N, &beta, d_C, N
        );
    }
    checkCudaErrors(cudaEventRecord(stop));
    checkCudaErrors(cudaEventSynchronize(stop));
    checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));

    checkCudaErrors(cudaMemcpy(h_C1, d_C, bytes_C, cudaMemcpyDeviceToHost));

    msecPerMatrixMul[1] = msecTotal / nIter;
    gigaFlops[1] = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul[1] / 1000.0f);
    printf("CuBlas Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops,\n",
        gigaFlops[1],
        msecPerMatrixMul[1],
        flopsPerMatrixMul);

    cublasDestroy(blas_handle);


    double eps = 1.e-6;  // relative error tolerance
    bool correct = true;
    for (int i = 0; i < M * N; i++) {
        int row = i / N;
        int col = i % N;
        double abs_err = fabs(h_C[i] - h_C1[col * M + row]);
        double dot_length = K;  // length of each dot product (K == M here)
        double abs_val = fabs(h_C[i]);
        double rel_err = abs_err / abs_val / dot_length;
        if (rel_err > eps) {
            printf("Error! Matrix[%d][%d]=%.8f, ref=%.8f error term is > %E\n",
                row, col, h_C[i], h_C1[col * M + row], eps);
            correct = false;
            break;
        }
    }

    printf("%s\n", correct ? "Result= PASS" : "Result= FAIL");
    printf("ratio= %f\n", gigaFlops[0] / gigaFlops[1]);

    // Free Memory
    cudaFree(d_A);
    cudaFree(d_B);
    cudaFree(d_C);

    free(h_A);
    free(h_B);
    free(h_C);
    free(h_C1);
}
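
// float4-vectorized NN kernel (row-major A, B, C) with double buffering and
// boundary checks, so M need not be a multiple of the tile sizes
// (K and N are still assumed to be multiples of 4, since A, B, and C are
// accessed as float4 along those dimensions)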
#include "cuda_runtime.h"
#include "device_launch_parameters.h"

#define BLOCK_X 16
#define BLOCK_Y 16

#define TILE_X 128
#define TILE_X_4 32
#define TILE_Y 128
#define TILE_Y_4 32

#define TILE_K 16

#define WPTN 8
#define WPTM 8
#define WPTN_4  2

__global__ void gemm_kernel_NN(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float4* __restrict__ C,
    float alpha, float beta,
    int M, int N, int K)
{
    __shared__ float4 smem_a[2][TILE_K * TILE_Y_4];
    __shared__ float4 smem_b[2][TILE_K * TILE_X_4];

    int tx = threadIdx.x % 16;
    int ty = threadIdx.x / 16;

    int tx4 = threadIdx.x % 4;
    int ty4 = threadIdx.x / 4;

    int tx32 = threadIdx.x % 32;
    int ty32 = threadIdx.x / 32;

    const float* pA = (A + K * TILE_Y * blockIdx.y + ty4 * K + tx4 * 4);
    const float* pB = (B + TILE_X * blockIdx.x + ty32 * N + tx32 * 4);
    float4* pC = C + TILE_Y * blockIdx.y * N / 4 + TILE_X_4 * blockIdx.x;

    int sts_a_offset = tx4 * 4 * TILE_Y + ty4;
    int sts_b_offset = ty32 * TILE_X_4 + tx32;
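    // smem_a holds the A tile transposed: float element (m = ty4, k = tx4*4 + c)
    // goes to offset (tx4*4 + c) * TILE_Y + ty4, so the compute loop can fetch
    // four consecutive m-values as one float4 per k; smem_b keeps B's row-major
    // layout with one float4 per thread per row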

    float4 f4_zero = make_float4(0.f, 0.f, 0.f, 0.f);
    bool valid_ld_a_0 = ((blockIdx.y * TILE_Y + ty4) < M) && ((tx4 * 4) < K);
    bool valid_ld_a_1 = ((blockIdx.y * TILE_Y + ty4 + 64) < M) && ((tx4 * 4) < K); 
    bool valid_ld_b_0 = ((blockIdx.x * TILE_X + tx32 * 4) < N) && (ty32 < K);
    bool valid_ld_b_1 = ((blockIdx.x * TILE_X + tx32 * 4) < N) && ((ty32 + 8) < K);

    float4 ldg_a_reg[2];
    float4 ldg_b_reg[2];

    ldg_a_reg[0] = valid_ld_a_0 ? *(const float4*)pA : f4_zero;
    ldg_a_reg[1] = valid_ld_a_1 ? *(const float4*)(pA + 64 * K) : f4_zero;
    ldg_b_reg[0] = valid_ld_b_0 ? *(const float4*)(pB + 0 * N) : f4_zero;
    ldg_b_reg[1] = valid_ld_b_1 ? *(const float4*)(pB + 8 * N) : f4_zero;

    float4 c[WPTM][WPTN_4] = { { f4_zero } };

    *((float*)&smem_a[0][0] + sts_a_offset + 0 * TILE_Y + 0) = ldg_a_reg[0].x;
    *((float*)&smem_a[0][0] + sts_a_offset + 1 * TILE_Y + 0) = ldg_a_reg[0].y;
    *((float*)&smem_a[0][0] + sts_a_offset + 2 * TILE_Y + 0) = ldg_a_reg[0].z;
    *((float*)&smem_a[0][0] + sts_a_offset + 3 * TILE_Y + 0) = ldg_a_reg[0].w;
    *((float*)&smem_a[0][0] + sts_a_offset + 0 * TILE_Y + 64) = ldg_a_reg[1].x;
    *((float*)&smem_a[0][0] + sts_a_offset + 1 * TILE_Y + 64) = ldg_a_reg[1].y;
    *((float*)&smem_a[0][0] + sts_a_offset + 2 * TILE_Y + 64) = ldg_a_reg[1].z;
    *((float*)&smem_a[0][0] + sts_a_offset + 3 * TILE_Y + 64) = ldg_a_reg[1].w;

    smem_b[0][sts_b_offset + 0] = ldg_b_reg[0];
    smem_b[0][sts_b_offset + 8 * TILE_X_4] = ldg_b_reg[1];

    __syncthreads();

    int i = 0;
    int write_stage_idx = 1;

    float4 reg_a[2][2];
    float4 reg_b[2][2];

    reg_a[0][0] = smem_a[0][0 + ty];
    reg_a[0][1] = smem_a[0][16 + ty];
    reg_b[0][0] = smem_b[0][0 + tx];
    reg_b[0][1] = smem_b[0][16 + tx];

    do
    {
        i += 16;
        valid_ld_a_0 = (valid_ld_a_0 && ((tx4 * 4 + i) < K));
        valid_ld_a_1 = (valid_ld_a_1 && ((tx4 * 4 + i) < K));
        valid_ld_b_0 = (valid_ld_b_0 && ((ty32 + i) < K));
        valid_ld_b_1 = (valid_ld_b_1 && ((ty32 + 8 + i) < K));

        ldg_a_reg[0] = (valid_ld_a_0) ? *(const float4*)(pA + i + 0) : f4_zero;
        ldg_a_reg[1] = (valid_ld_a_1) ? *(const float4*)(pA + i + 64 * K) : f4_zero;
        ldg_b_reg[0] = (valid_ld_b_0) ? *(const float4*)(pB + (i + 0) * N) : f4_zero;
        ldg_b_reg[1] = (valid_ld_b_1) ? *(const float4*)(pB + (i + 8) * N) : f4_zero;

        int load_stage_idx = write_stage_idx ^ 1;

#pragma unroll
        for (int j = 0; j < TILE_K - 1; j++)
        {
            reg_a[(j + 1) % 2][0] = smem_a[load_stage_idx][(j + 1) *  TILE_Y_4 + 0 + ty];
            reg_a[(j + 1) % 2][1] = smem_a[load_stage_idx][(j + 1) *  TILE_Y_4 + 16 + ty];
            reg_b[(j + 1) % 2][0] = smem_b[load_stage_idx][(j + 1) *  TILE_X_4 + 0 + tx];
            reg_b[(j + 1) % 2][1] = smem_b[load_stage_idx][(j + 1) *  TILE_X_4 + 16 + tx];
            c[0][0].x += reg_a[j % 2][0].x * reg_b[j % 2][0].x;
            c[0][0].y += reg_a[j % 2][0].x * reg_b[j % 2][0].y;
            c[0][0].z += reg_a[j % 2][0].x * reg_b[j % 2][0].z;
            c[0][0].w += reg_a[j % 2][0].x * reg_b[j % 2][0].w;
            c[0][1].x += reg_a[j % 2][0].x * reg_b[j % 2][1].x;
            c[0][1].y += reg_a[j % 2][0].x * reg_b[j % 2][1].y;
            c[0][1].z += reg_a[j % 2][0].x * reg_b[j % 2][1].z;
            c[0][1].w += reg_a[j % 2][0].x * reg_b[j % 2][1].w;
            c[1][0].x += reg_a[j % 2][0].y * reg_b[j % 2][0].x;
            c[1][0].y += reg_a[j % 2][0].y * reg_b[j % 2][0].y;
            c[1][0].z += reg_a[j % 2][0].y * reg_b[j % 2][0].z;
            c[1][0].w += reg_a[j % 2][0].y * reg_b[j % 2][0].w;
            c[1][1].x += reg_a[j % 2][0].y * reg_b[j % 2][1].x;
            c[1][1].y += reg_a[j % 2][0].y * reg_b[j % 2][1].y;
            c[1][1].z += reg_a[j % 2][0].y * reg_b[j % 2][1].z;
            c[1][1].w += reg_a[j % 2][0].y * reg_b[j % 2][1].w;
            c[2][0].x += reg_a[j % 2][0].z * reg_b[j % 2][0].x;
            c[2][0].y += reg_a[j % 2][0].z * reg_b[j % 2][0].y;
            c[2][0].z += reg_a[j % 2][0].z * reg_b[j % 2][0].z;
            c[2][0].w += reg_a[j % 2][0].z * reg_b[j % 2][0].w;
            c[2][1].x += reg_a[j % 2][0].z * reg_b[j % 2][1].x;
            c[2][1].y += reg_a[j % 2][0].z * reg_b[j % 2][1].y;
            c[2][1].z += reg_a[j % 2][0].z * reg_b[j % 2][1].z;
            c[2][1].w += reg_a[j % 2][0].z * reg_b[j % 2][1].w;
            c[3][0].x += reg_a[j % 2][0].w * reg_b[j % 2][0].x;
            c[3][0].y += reg_a[j % 2][0].w * reg_b[j % 2][0].y;
            c[3][0].z += reg_a[j % 2][0].w * reg_b[j % 2][0].z;
            c[3][0].w += reg_a[j % 2][0].w * reg_b[j % 2][0].w;
            c[3][1].x += reg_a[j % 2][0].w * reg_b[j % 2][1].x;
            c[3][1].y += reg_a[j % 2][0].w * reg_b[j % 2][1].y;
            c[3][1].z += reg_a[j % 2][0].w * reg_b[j % 2][1].z;
            c[3][1].w += reg_a[j % 2][0].w * reg_b[j % 2][1].w;
            c[4][0].x += reg_a[j % 2][1].x * reg_b[j % 2][0].x;
            c[4][0].y += reg_a[j % 2][1].x * reg_b[j % 2][0].y;
            c[4][0].z += reg_a[j % 2][1].x * reg_b[j % 2][0].z;
            c[4][0].w += reg_a[j % 2][1].x * reg_b[j % 2][0].w;
            c[4][1].x += reg_a[j % 2][1].x * reg_b[j % 2][1].x;
            c[4][1].y += reg_a[j % 2][1].x * reg_b[j % 2][1].y;
            c[4][1].z += reg_a[j % 2][1].x * reg_b[j % 2][1].z;
            c[4][1].w += reg_a[j % 2][1].x * reg_b[j % 2][1].w;
            c[5][0].x += reg_a[j % 2][1].y * reg_b[j % 2][0].x;
            c[5][0].y += reg_a[j % 2][1].y * reg_b[j % 2][0].y;
            c[5][0].z += reg_a[j % 2][1].y * reg_b[j % 2][0].z;
            c[5][0].w += reg_a[j % 2][1].y * reg_b[j % 2][0].w;
            c[5][1].x += reg_a[j % 2][1].y * reg_b[j % 2][1].x;
            c[5][1].y += reg_a[j % 2][1].y * reg_b[j % 2][1].y;
            c[5][1].z += reg_a[j % 2][1].y * reg_b[j % 2][1].z;
            c[5][1].w += reg_a[j % 2][1].y * reg_b[j % 2][1].w;
            c[6][0].x += reg_a[j % 2][1].z * reg_b[j % 2][0].x;
            c[6][0].y += reg_a[j % 2][1].z * reg_b[j % 2][0].y;
            c[6][0].z += reg_a[j % 2][1].z * reg_b[j % 2][0].z;
            c[6][0].w += reg_a[j % 2][1].z * reg_b[j % 2][0].w;
            c[6][1].x += reg_a[j % 2][1].z * reg_b[j % 2][1].x;
            c[6][1].y += reg_a[j % 2][1].z * reg_b[j % 2][1].y;
            c[6][1].z += reg_a[j % 2][1].z * reg_b[j % 2][1].z;
            c[6][1].w += reg_a[j % 2][1].z * reg_b[j % 2][1].w;
            c[7][0].x += reg_a[j % 2][1].w * reg_b[j % 2][0].x;
            c[7][0].y += reg_a[j % 2][1].w * reg_b[j % 2][0].y;
            c[7][0].z += reg_a[j % 2][1].w * reg_b[j % 2][0].z;
            c[7][0].w += reg_a[j % 2][1].w * reg_b[j % 2][0].w;
            c[7][1].x += reg_a[j % 2][1].w * reg_b[j % 2][1].x;
            c[7][1].y += reg_a[j % 2][1].w * reg_b[j % 2][1].y;
            c[7][1].z += reg_a[j % 2][1].w * reg_b[j % 2][1].z;
            c[7][1].w += reg_a[j % 2][1].w * reg_b[j % 2][1].w;
        }
        
        if(i < K) {
            *((float*)&smem_a[write_stage_idx][0] + sts_a_offset + 0 * TILE_Y + 0) = ldg_a_reg[0].x;
            *((float*)&smem_a[write_stage_idx][0] + sts_a_offset + 1 * TILE_Y + 0) = ldg_a_reg[0].y;
            *((float*)&smem_a[write_stage_idx][0] + sts_a_offset + 2 * TILE_Y + 0) = ldg_a_reg[0].z;
            *((float*)&smem_a[write_stage_idx][0] + sts_a_offset + 3 * TILE_Y + 0) = ldg_a_reg[0].w;
            *((float*)&smem_a[write_stage_idx][0] + sts_a_offset + 0 * TILE_Y + 64) = ldg_a_reg[1].x;
            *((float*)&smem_a[write_stage_idx][0] + sts_a_offset + 1 * TILE_Y + 64) = ldg_a_reg[1].y;
            *((float*)&smem_a[write_stage_idx][0] + sts_a_offset + 2 * TILE_Y + 64) = ldg_a_reg[1].z;
            *((float*)&smem_a[write_stage_idx][0] + sts_a_offset + 3 * TILE_Y + 64) = ldg_a_reg[1].w;

            smem_b[write_stage_idx][sts_b_offset + 0] = ldg_b_reg[0];
            smem_b[write_stage_idx][sts_b_offset + 8 * TILE_X_4] = ldg_b_reg[1];
            __syncthreads();
            write_stage_idx ^= 1;
        }

        reg_a[0][0] = smem_a[load_stage_idx ^ 1][0 + ty];
        reg_a[0][1] = smem_a[load_stage_idx ^ 1][16 + ty];
        reg_b[0][0] = smem_b[load_stage_idx ^ 1][0 + tx];
        reg_b[0][1] = smem_b[load_stage_idx ^ 1][16 + tx];

        c[0][0].x += reg_a[1][0].x * reg_b[1][0].x;
        c[0][0].y += reg_a[1][0].x * reg_b[1][0].y;
        c[0][0].z += reg_a[1][0].x * reg_b[1][0].z;
        c[0][0].w += reg_a[1][0].x * reg_b[1][0].w;
        c[0][1].x += reg_a[1][0].x * reg_b[1][1].x;
        c[0][1].y += reg_a[1][0].x * reg_b[1][1].y;
        c[0][1].z += reg_a[1][0].x * reg_b[1][1].z;
        c[0][1].w += reg_a[1][0].x * reg_b[1][1].w;
        c[1][0].x += reg_a[1][0].y * reg_b[1][0].x;
        c[1][0].y += reg_a[1][0].y * reg_b[1][0].y;
        c[1][0].z += reg_a[1][0].y * reg_b[1][0].z;
        c[1][0].w += reg_a[1][0].y * reg_b[1][0].w;
        c[1][1].x += reg_a[1][0].y * reg_b[1][1].x;
        c[1][1].y += reg_a[1][0].y * reg_b[1][1].y;
        c[1][1].z += reg_a[1][0].y * reg_b[1][1].z;
        c[1][1].w += reg_a[1][0].y * reg_b[1][1].w;
        c[2][0].x += reg_a[1][0].z * reg_b[1][0].x;
        c[2][0].y += reg_a[1][0].z * reg_b[1][0].y;
        c[2][0].z += reg_a[1][0].z * reg_b[1][0].z;
        c[2][0].w += reg_a[1][0].z * reg_b[1][0].w;
        c[2][1].x += reg_a[1][0].z * reg_b[1][1].x;
        c[2][1].y += reg_a[1][0].z * reg_b[1][1].y;
        c[2][1].z += reg_a[1][0].z * reg_b[1][1].z;
        c[2][1].w += reg_a[1][0].z * reg_b[1][1].w;
        c[3][0].x += reg_a[1][0].w * reg_b[1][0].x;
        c[3][0].y += reg_a[1][0].w * reg_b[1][0].y;
        c[3][0].z += reg_a[1][0].w * reg_b[1][0].z;
        c[3][0].w += reg_a[1][0].w * reg_b[1][0].w;
        c[3][1].x += reg_a[1][0].w * reg_b[1][1].x;
        c[3][1].y += reg_a[1][0].w * reg_b[1][1].y;
        c[3][1].z += reg_a[1][0].w * reg_b[1][1].z;
        c[3][1].w += reg_a[1][0].w * reg_b[1][1].w;
        c[4][0].x += reg_a[1][1].x * reg_b[1][0].x;
        c[4][0].y += reg_a[1][1].x * reg_b[1][0].y;
        c[4][0].z += reg_a[1][1].x * reg_b[1][0].z;
        c[4][0].w += reg_a[1][1].x * reg_b[1][0].w;
        c[4][1].x += reg_a[1][1].x * reg_b[1][1].x;
        c[4][1].y += reg_a[1][1].x * reg_b[1][1].y;
        c[4][1].z += reg_a[1][1].x * reg_b[1][1].z;
        c[4][1].w += reg_a[1][1].x * reg_b[1][1].w;
        c[5][0].x += reg_a[1][1].y * reg_b[1][0].x;
        c[5][0].y += reg_a[1][1].y * reg_b[1][0].y;
        c[5][0].z += reg_a[1][1].y * reg_b[1][0].z;
        c[5][0].w += reg_a[1][1].y * reg_b[1][0].w;
        c[5][1].x += reg_a[1][1].y * reg_b[1][1].x;
        c[5][1].y += reg_a[1][1].y * reg_b[1][1].y;
        c[5][1].z += reg_a[1][1].y * reg_b[1][1].z;
        c[5][1].w += reg_a[1][1].y * reg_b[1][1].w;
        c[6][0].x += reg_a[1][1].z * reg_b[1][0].x;
        c[6][0].y += reg_a[1][1].z * reg_b[1][0].y;
        c[6][0].z += reg_a[1][1].z * reg_b[1][0].z;
        c[6][0].w += reg_a[1][1].z * reg_b[1][0].w;
        c[6][1].x += reg_a[1][1].z * reg_b[1][1].x;
        c[6][1].y += reg_a[1][1].z * reg_b[1][1].y;
        c[6][1].z += reg_a[1][1].z * reg_b[1][1].z;
        c[6][1].w += reg_a[1][1].z * reg_b[1][1].w;
        c[7][0].x += reg_a[1][1].w * reg_b[1][0].x;
        c[7][0].y += reg_a[1][1].w * reg_b[1][0].y;
        c[7][0].z += reg_a[1][1].w * reg_b[1][0].z;
        c[7][0].w += reg_a[1][1].w * reg_b[1][0].w;
        c[7][1].x += reg_a[1][1].w * reg_b[1][1].x;
        c[7][1].y += reg_a[1][1].w * reg_b[1][1].y;
        c[7][1].z += reg_a[1][1].w * reg_b[1][1].z;
        c[7][1].w += reg_a[1][1].w * reg_b[1][1].w;
        
    } while (i < K);

#pragma unroll
    for (int wm = 0; wm < WPTM; wm++){
#pragma unroll
        for (int wn = 0; wn < WPTN_4; wn++){
            c[wm][wn].x *= alpha;
            c[wm][wn].y *= alpha;
            c[wm][wn].z *= alpha;
            c[wm][wn].w *= alpha;
        }
    }

#pragma unroll
    for (int wm = 0; wm < 4; wm++){
#pragma unroll
        for (int wn = 0; wn < WPTN_4; wn++){
            if (((blockIdx.y * TILE_Y + ty * 4 + wm) < M) 
                && ((blockIdx.x * TILE_X + wn * 64 + tx * 4) < N)) {
                if (beta != 0) {
                    float4 vec4c = *(pC + ((ty * 4 + wm) * N / 4 + wn * 16 + tx));
                    vec4c.x = vec4c.x * beta + c[wm][wn].x;
                    vec4c.y = vec4c.y * beta + c[wm][wn].y;
                    vec4c.z = vec4c.z * beta + c[wm][wn].z;
                    vec4c.w = vec4c.w * beta + c[wm][wn].w;
                    *(pC + (ty * 4 + wm) * N / 4 + wn * 16 + tx) = vec4c;
                } else {
                    *(pC + (ty * 4 + wm) * N / 4 + wn * 16 + tx) = c[wm][wn];
                }
            }
        }
    }

#pragma unroll
    for (int wm = 0; wm < 4; wm++){
#pragma unroll
        for (int wn = 0; wn < WPTN_4; wn++){
            if (((blockIdx.y * TILE_Y + 64 + ty * 4 + wm) < M) 
                && ((blockIdx.x * TILE_X + wn * 64 + tx * 4) < N)) {
                if (beta != 0) {
                    float4 vec4c = *(pC + ((64 + ty * 4 + wm) * N / 4 + wn * 16 + tx));
                    vec4c.x = vec4c.x * beta + c[wm + 4][wn].x;
                    vec4c.y = vec4c.y * beta + c[wm + 4][wn].y;
                    vec4c.z = vec4c.z * beta + c[wm + 4][wn].z;
                    vec4c.w = vec4c.w * beta + c[wm + 4][wn].w;
                    *(pC + (64 + ty * 4 + wm) * N / 4 + wn * 16 + tx) = vec4c;
                } else {
                    *(pC + (64 + ty * 4 + wm) * N / 4 + wn * 16 + tx) = c[wm + 4][wn];
                }
            }
        }
    }
}
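
A minimal host-side launch sketch for gemm_kernel_NN (my addition; the original source does not include a launcher): 256 threads per block and one block per 128 x 128 tile of C, with C reinterpreted as float4.

// hypothetical launcher; grid/block shapes follow the kernel's indexing
void launch_gemm_nn(const float* dA, const float* dB, float* dC,
                    float alpha, float beta, int M, int N, int K) {
    dim3 block(BLOCK_X * BLOCK_Y);  // 256 flat threads
    dim3 grid((N + TILE_X - 1) / TILE_X, (M + TILE_Y - 1) / TILE_Y);
    gemm_kernel_NN<<<grid, block>>>(dA, dB,
                                    reinterpret_cast<float4*>(dC),
                                    alpha, beta, M, N, K);
}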
