五维度数组CUDA

#include <iostream>
#include <assert.h>

template<typename T>
T***** create_5d_flat(int a, int b, int c, int d,int e) {
	T *base;
	cudaError_t err = cudaMallocManaged(&base, a*b*c*d*e * sizeof(T));
	assert(err == cudaSuccess);
	T *****ary;
	err = cudaMallocManaged(&ary, (a + a * b + a * b*c + a * b*c*d) * sizeof(T*));
	assert(err == cudaSuccess);
	for (int i = 0; i < a; i++) 
	{
		ary[i] = (T ****)((ary + a) + i * b);
		for (int j = 0; j < b; j++) 
		{
			ary[i][j] = (T ***)((ary + a + a * b) + (i * b + j)* c);
			for (int k = 0; k < c; k++)
			{
				ary[i][j][k] = (T **)((ary + a + a * b+a*b*c) + ((i * b + j)* c+k)*d      );
				for (int l = 0; l < d; l++)
					ary[i][j][k][l] = base + (((i*b + j)*c + k)*d+l)*e;
			}

		}
	}
	return ary;
}

template<typename T>
void free_5d_flat(T***** ary) {
	if (ary[0][0][0][0]) cudaFree(ary[0][0][0][0]);
	if (ary) cudaFree(ary);
}


template<typename T>
__global__ void fill(T***** data, int a, int b, int c, int d, int e) {
	unsigned long long int val = 0;
	for (int i = 0; i < a; i++)
		for (int j = 0; j < b; j++)
			for (int k = 0; k < c; k++)
				for (int l = 0; l < d; l++)
					for (int m = 0; m < e; m++)
						data[i][j][k][l][m] = val++;
}

void report_gpu_mem()
{
	size_t free, total;
	cudaMemGetInfo(&free, &total);
	std::cout << "Free = " << free << " Total = " << total << std::endl;
}

int main() {
	report_gpu_mem();

	unsigned long long int *****data2;
	std::cout << "allocating..." << std::endl;
	data2 = create_5d_flat<unsigned long long int>(6, 9, 3, 5,4);

	report_gpu_mem();

	fill << <1, 1 >> > (data2, 6, 9, 3, 5, 4);
	cudaError_t err = cudaDeviceSynchronize();
	assert(err == cudaSuccess);






	std::cout << "validating..." << std::endl;




	//std::cout << *(data2[0] ) << std::endl;
	//std::cout << *data2[1] << std::endl;
	//std::cout << &data2[1][0][0][0] << std::endl;

	for (int i = 0; i < 6; i++)
		for (int j = 0; j < 9; j++)
			for (int k = 0; k < 3; k++)
				for (int l = 0; l < 5; l++)
					for (int m = 0; m < 4; m++)
						std::cout << data2[i][j][k][l][m] << std::endl;


	for (int i = 0; i < 6* 9* 3* 5* 4; i++)
		if (*(data2[0][0][0][0] + i) != i) 
		{
			std::cout << "mismatch at " << i << " was " << *(data2[0][0][0][0] + i) << std::endl;
			return -1;
		}

	free_5d_flat(data2);
	return 0;
}

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值