前文CUDA并行归约算法的示意图如下。分析可知,相邻线程执行不同的路径,存在线程束分化。
为了避免线程束分化,使每个warp(32个线程)执行同一指令,可调整相邻线程访问的数组索引来实现优化。示意图如下图所示,数组的存储位置没变,只是每个线程处理的数组元素发生了变化,这样的处理方式可以减少相邻线程的分化,并尽早释放后面的线程。
实验在GTX 1050Ti上进行,线程块长度为1024时,性能提升1.73倍左右;但随着线程块长度的减小,性能提升也有所降低,这主要和warp size有关。代码如下:
#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include <stdio.h>
#include "math.h"
#include "stdlib.h"
// Error-checking macro for CUDA runtime calls.
// Evaluates `call` exactly once; if the result is not cudaSuccess it prints
// the file, function, line and the CUDA error string to stderr and terminates
// the process with exit code 1.
// Wrapped in do { } while (0) so that `CHECK(x);` behaves as a single
// statement (safe inside an unbraced if/else). The last line deliberately has
// no trailing backslash, so the macro does not swallow the next source line.
#define CHECK(call) \
do { \
    const cudaError_t check_status_ = (call); \
    if (check_status_ != cudaSuccess) \
    { \
        fprintf(stderr, "文件:%s,函数:%s,行号:%d: %s\n", __FILE__, \
                __FUNCTION__, __LINE__, cudaGetErrorString(check_status_)); \
        exit(1); \
    } \
} while (0)
// Kernel: per-block, in-place tree reduction (integer sum) in global memory.
//
// Launch layout: 1D grid of 1D blocks. Block b reduces the slice of d_data
// that starts at b * blockDim.x and holds up to blockDim.x elements; the slice
// of the last block is clamped to N. Any block size works (it does not have
// to be a power of two).
//
// Indexing: thread `tid` owns element tid * 2 * stride, so in every round the
// threads that still have work are the ones with the lowest thread IDs.
//
// Parameters:
//   d_data      - device array with at least N ints. It is used as scratch
//                 space: its contents are overwritten by partial sums.
//   d_local_sum - device array with at least gridDim.x ints; entry b receives
//                 the sum of block b's slice (0 if the slice is empty).
//   N           - number of valid elements in d_data.
__global__ void Kernel(int *d_data, int *d_local_sum, int N)
{
    const int tid = threadIdx.x;
    const int blockSize = (int)blockDim.x;

    // 64-bit offset so blockIdx.x * blockDim.x cannot wrap for large grids.
    const long long blockStart = (long long)blockIdx.x * blockSize;
    int *data = d_data + blockStart;

    // Number of valid elements in this block's slice (0..blockSize).
    // The value is identical for every thread of the block, so the loop below
    // has a block-uniform trip count and every thread reaches each barrier.
    const long long remaining = (long long)N - blockStart;
    int count = 0;
    if (remaining > 0)
        count = (remaining < blockSize) ? (int)remaining : blockSize;

    // No early return for out-of-range threads: they simply skip the add but
    // still take part in __syncthreads().
    for (int stride = 1; stride < count; stride *= 2)
    {
        const int idx = tid * stride * 2;
        // Guard on the partner element, not just on idx, so we never read
        // past the end of this block's slice (or past N in the tail block).
        if (idx + stride < count)
            data[idx] += data[idx + stride];
        __syncthreads();
    }

    if (tid == 0)
    {
        d_local_sum[blockIdx.x] = (count > 0) ? data[0] : 0;
    }
}
// Host driver: fills an array of N ints, launches Kernel to produce one
// partial sum per thread block, adds the partial sums on the host and prints
// the total. Returns 0 on success, 1 if host allocation fails; any CUDA error
// terminates the process through CHECK.
int main()
{
    // Basic parameters.
    CHECK(cudaSetDevice(0));
    const int N = 65536;
    const int local_length = 1024;  // threads (and elements) per block
    int total_sum = 0;
    dim3 grid(((N + local_length - 1) / local_length), 1);  // ceil-div
    dim3 block(local_length, 1);
    const int num_blocks = int(grid.x);
    const size_t data_bytes = N * sizeof(int);
    const size_t sums_bytes = num_blocks * sizeof(int);

    int *h_data = nullptr;
    int *h_local_sum = nullptr;
    int *d_data = nullptr;
    int *d_local_sum = nullptr;

    // Host allocation (checked: malloc may return NULL).
    h_data = (int*)malloc(data_bytes);
    h_local_sum = (int*)malloc(sums_bytes);
    if (h_data == nullptr || h_local_sum == nullptr)
    {
        fprintf(stderr, "host memory allocation failed\n");
        free(h_data);
        free(h_local_sum);
        return 1;
    }

    // Device allocation.
    CHECK(cudaMalloc((void**)&d_data, data_bytes));
    CHECK(cudaMalloc((void**)&d_local_sum, sums_bytes));

    // Input data: every element lies in [-10, 10], so the total over N
    // elements stays well inside the range of int.
    for (int i = 0; i < N; i++)
        h_data[i] = int(10 * sin(0.02*3.14*i));

    // Copy input to the device.
    CHECK(cudaMemcpy(d_data, h_data, data_bytes, cudaMemcpyHostToDevice));

    // Launch the kernel. A launch does not return a status itself:
    // cudaGetLastError() reports bad launch configurations, and the
    // synchronize reports faults raised while the kernel was running.
    Kernel<<<grid, block>>>(d_data, d_local_sum, N);
    CHECK(cudaGetLastError());
    CHECK(cudaDeviceSynchronize());

    // Copy the per-block partial sums back to the host.
    CHECK(cudaMemcpy(h_local_sum, d_local_sum, sums_bytes,
        cudaMemcpyDeviceToHost));

    // Release device resources, then reset the device.
    CHECK(cudaFree(d_data));
    CHECK(cudaFree(d_local_sum));
    CHECK(cudaDeviceReset());

    // Final reduction of the partial sums on the host.
    for (int i = 0; i < num_blocks; i++)
    {
        total_sum += h_local_sum[i];
    }
    printf("%d \n", total_sum);

    free(h_data);
    free(h_local_sum);
    //getchar();
    return 0;
}