// 1. Matrix multiplication (乘法)
#include <cstdio>
#include <iostream>
#include <chrono>
#define TILE_WIDTH 32
// Tiled square matrix multiply: pd = md * nd, where all three matrices are
// width x width and stored row-major (element (r, c) at index r * width + c).
//
// Launch layout: 2D grid of 2D blocks with blockDim == (TILE_WIDTH, TILE_WIDTH).
// Block (bx, by) computes the output tile whose top-left element is
// (by * TILE_WIDTH, bx * TILE_WIDTH). Every output element is produced only if
// the grid has at least ceil(width / TILE_WIDTH) blocks in each dimension.
//
// Shared memory: two static TILE_WIDTH x TILE_WIDTH float tiles per block.
//
// width does NOT need to be a multiple of TILE_WIDTH. Tile elements that fall
// outside the matrix are loaded as 0 (contributing nothing to the dot product),
// and threads whose (row, col) lies outside the matrix write nothing.
//
// md, nd : input matrices (device memory, read only).
// pd     : output matrix (device memory); must not alias md or nd.
__global__ void MatrixMulKernel(const float* __restrict__ md,
                                const float* __restrict__ nd,
                                float* __restrict__ pd, int width) {
    __shared__ float mds[TILE_WIDTH][TILE_WIDTH];
    __shared__ float nds[TILE_WIDTH][TILE_WIDTH];

    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int row = blockIdx.y * TILE_WIDTH + ty;  // output row owned by this thread
    const int col = blockIdx.x * TILE_WIDTH + tx;  // output column owned by this thread

    // Ceil-div so a trailing partial tile is still accumulated.
    const int numTiles = (width + TILE_WIDTH - 1) / TILE_WIDTH;

    float p_value = 0.0f;
    for (int t = 0; t < numTiles; ++t) {
        const int mCol = t * TILE_WIDTH + tx;  // column of md loaded by this thread
        const int nRow = t * TILE_WIDTH + ty;  // row of nd loaded by this thread

        // Guarded, zero-filled loads. They are expressed as selects, not as a branch
        // around the barrier, so every thread in the block reaches __syncthreads().
        mds[ty][tx] = (row < width && mCol < width)
                          ? md[static_cast<size_t>(row) * width + mCol]
                          : 0.0f;
        nds[ty][tx] = (nRow < width && col < width)
                          ? nd[static_cast<size_t>(nRow) * width + col]
                          : 0.0f;

        // Both tiles must be fully written before any thread reads them.
        __syncthreads();

#pragma unroll
        for (int j = 0; j < TILE_WIDTH; ++j) {
            p_value += mds[ty][j] * nds[j][tx];
        }

        // Keep the next iteration from overwriting tiles that are still being read.
        __syncthreads();
    }

    if (row < width && col < width) {
        pd[static_cast<size_t>(row) * width + col] = p_value;
    }
}
// Host driver. Most of its body is elided ("......") in this listing.
// NOTE(review): md, nd, pd and width are set up in the elided code. md, nd and pd
// are presumably device allocations holding width x width floats; confirm
// against the full source.
int main() {
......
// Launches a 2D grid of (width / TILE_WIDTH) x (width / TILE_WIDTH) blocks, each with
// TILE_WIDTH x TILE_WIDTH threads (1024 threads per block for TILE_WIDTH == 32).
// The integer division floors, so this grid covers the whole width x width output
// only when width is a multiple of TILE_WIDTH.
// NOTE(review): no cudaGetLastError()/synchronization check is visible after this
// launch. If the elided code does not check, launch or execution errors go unnoticed.
MatrixMulKernel<<<dim3(width/TILE_WIDTH, width/TILE_WIDTH), dim3(TILE_WIDTH, TILE_WIDTH)>>>(md, nd, pd, width);
......
}