Thrust 的自定义函数(函数对象,functor)非常灵活,给编写程序带来很大方便。上次的自定义函数不带参数,这次
笔者写一个带参数的自定义函数:参数在构造时传入并保存为成员变量,再在 operator() 中使用。通过这次练习,
笔者对如何用 thrust::transform 调用自定义函数有了更深的理解。
代码:
#include <thrust/device_vector.h>
#include <thrust/transform.h>
#include <thrust/sequence.h>
#include <thrust/copy.h>
#include <thrust/fill.h>
#include <thrust/replace.h>
#include <thrust/functional.h>
#include <iostream>
#include <iterator>
#include <stdexcept>
// Selects the SAXPY implementation run by main(): nonzero -> saxpy_fast
// (one fused transform), 0 -> saxpy_slow (temporary vector, several passes).
// Guarded so it can be overridden from the compiler command line, e.g. -DD_SW=0.
#ifndef D_SW
#define D_SW 1
#endif
// Binary functor computing a * x + y, with the scalar `a` fixed at
// construction time. Marked __host__ __device__ so it can be invoked both on
// the host and inside device code (e.g. by thrust::transform on a
// device_vector).
struct saxpy_functor
{
    const float a;

    // explicit: prevents a bare float from silently converting into a
    // functor wherever a saxpy_functor is expected.
    explicit saxpy_functor(float a_value) : a(a_value) {}

    // Returns a * x + y for one pair of elements.
    __host__ __device__
    float operator()(const float& x, const float& y) const {
        return a * x + y;
    }
};
void saxpy_fast(float A, thrust::device_vector<float>&X,
thrust::device_vector<float>&Y) {
//Y<- A*X+Y
thrust::transform(X.begin(), X.end(), Y.begin(), Y.begin(),
saxpy_functor(A));
}
void saxpy_slow(float A, thrust::device_vector<float>&X,
thrust::device_vector<float>&Y) {
thrust::device_vector<float>temp(X.size());
//temp=A
thrust::fill(temp.begin(), temp.end(),A);
//temp=A*X
thrust::transform(X.begin(), X.end(), temp.begin(), temp.begin(),
thrust::multiplies<float>());
//Y=A*X+Y
thrust::transform(temp.begin(), temp.end(), Y.begin(), Y.begin(),
thrust::plus<float>());
}
// Demo driver: fills x and y with 0, 1, ..., 9, computes y <- a * x + y with
// the implementation selected by D_SW, and prints y one value per line.
int main(void) {
    const float a = -1.0f;
    thrust::device_vector<float> x(10);
    thrust::device_vector<float> y(10);
    thrust::sequence(x.begin(), x.end());
    thrust::sequence(y.begin(), y.end());
#if D_SW
    saxpy_fast(a, x, y);
#else
    saxpy_slow(a, x, y);
#endif
    // Stream the result to the host with one thrust::copy instead of reading
    // y[i] element by element (each device_vector element access is its own
    // device-to-host read). "\n" avoids the per-line flush of std::endl.
    thrust::copy(y.begin(), y.end(),
                 std::ostream_iterator<float>(std::cout, "\n"));
    return 0;
}