// ex2: sorts every row of a matrix in parallel.
#include <iostream>
#include <iterator>
#include <string>

#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/generate.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/sort.h>
//compile: nvcc cuda_ex3.cu -o a.out
//profiling: nvprof ./a.out
// Prints `name` on a line of its own, followed by every element of `v`,
// each terminated by a newline, to std::cout.
//
//   v    - host-side vector whose elements are streamed with operator<<
//   name - optional heading; an empty heading still produces an empty line
template<typename T>
void print_array(const thrust::host_vector<T>& v, std::string name = "")
{
    std::cout << name << "\n";
    for (const T& element : v)
    {
        std::cout << element << "\n";
    }
}
// Returns the next pseudo-random integer drawn from the closed range [0, 9999].
//
// The engine and the distribution are function-local statics, so they are
// constructed once and carry their state from one call to the next. The
// engine is default-constructed (no explicit seed is supplied here).
int my_rand()
{
    static thrust::uniform_int_distribution<int> uniform(0, 9999);
    static thrust::default_random_engine engine;
    return uniform(engine);
}
/***********************************
Logical layout of the test matrix:
    int data[numRows][numElem];
Each row data[r][0 .. numElem-1] is sorted independently of the others.
The matrix is held in one flat buffer of TotalSize elements; code in this
file addresses row r at offset r*numElem.
************************************/
// Number of rows in the matrix (600,000).
static const int numRows=60*10000;
// Number of elements in each row.
static const int numElem=100;
// Total element count of the flat buffer (60,000,000; fits in a 32-bit int).
static const int TotalSize=numRows*numElem;
// Functor that sorts one row of a flat matrix in place, in ascending order.
//
// The matrix is addressed as numElem consecutive ints per row starting at
// m_data, i.e. row r occupies [m_data + r*numElem, m_data + (r+1)*numElem).
// The struct is kept an aggregate so it can be brace-initialised from a
// single device pointer.
struct MyFunctor{
    // Pointer to the first element of the matrix.
    thrust::device_ptr<int> m_data;

    // Sorts the numElem elements of row `row`.
    //   row - zero-based row index; the caller must keep it inside the matrix.
    __host__ __device__ void operator()(int row)
    {
        auto begin=m_data+row*numElem;
        // Each invocation handles exactly one row, so the sort is requested
        // with thrust::seq: it runs sequentially in the calling thread.
        // A device-wide parallel policy here would ask for a nested parallel
        // launch per row, whose outcome depends on build flags and on the
        // Thrust version in use.
        thrust::sort(thrust::seq,begin,begin+numElem);
    }
};
// End-to-end self-check of the row sort.
//
// Fills a TotalSize-element host buffer with values from my_rand, copies it
// into a device_vector, applies MyFunctor to every row index in
// [0, numRows), copies the data back and verifies on the host that each row
// of numElem elements is sorted. The verdict is written to std::cout.
void test()
{
    // Step 1: random input generated on the host, then copied into a device_vector.
    thrust::host_vector<int> host_data(TotalSize);
    thrust::generate(host_data.begin(),host_data.end(),my_rand);
    thrust::device_vector<int> device_data(host_data);

    // Step 2: one functor call per row index.
    // No execution policy is passed to for_each.
    // NOTE(review): presumably Thrust dispatches a counting_iterator range to
    // the device backend by default -- confirm against the Thrust docs.
    MyFunctor sort_row{device_data.data()};
    thrust::for_each(thrust::counting_iterator<int>(0),
                     thrust::counting_iterator<int>(numRows),
                     sort_row);

    // Step 3: bring the data back and validate it row by row on the host.
    host_data=device_data;
    // Debug aid (disabled): dump the last row.
    //thrust::copy(host_data.begin()+TotalSize-numElem, host_data.end(), std::ostream_iterator<decltype(*host_data.begin())>(std::cout, "\n"));
    bool all_sorted=true;
    int row=0;
    // The flag in the loop condition stops the scan at the first unsorted row.
    while(all_sorted && row<numRows)
    {
        auto row_first=host_data.begin()+row*numElem;
        all_sorted=thrust::is_sorted(row_first,row_first+numElem);
        ++row;
    }

    const char* verdict=all_sorted?"success":"fail";
    std::cout<<"check all-row-sorted status: "<<verdict<<"\n";
}
// Program entry point: runs the row-sort self-check once.
// The command-line arguments are accepted but not used; the exit status is
// always 0.
int main(int argc, char **argv)
{
    (void)argc;
    (void)argv;
    test();
    return 0;
}