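Template Functions and Template Classes
I. Template functions
If a function template is declared in a header, its definition generally has to live in the header as well. If it is defined in a source file like an ordinary function, other translation units see only the declaration, no instantiation is generated for them, and the build typically fails at link time with an unresolved external symbol error. The exception is when the source file explicitly instantiates the template for every type that is used.
Correct example:
Declaration in the header (.hpp file):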
template <typename Ftype>
cudaError_t Forward_gpu(const int count, const int channels, const int dim,
                        const Ftype *mDeviceKernel,
                        const Ftype *bottom_data, Ftype *top_data,
                        const Ftype zero,
                        const int div_factor,
                        cudaStream_t stream);
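Source file (.cu, since this is CUDA code). After the implementation, the source file also needs explicit instantiation definitions (at the end of the file):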
#include <cuda_runtime.h>  // cudaError_t, cudaStream_t, cudaGetLastError
#include <cuda_fp16.h>     // __half
// CUDA: use 512 threads per block
const int CAFFE_CUDA_NUM_THREADS = 512;
// CUDA: number of blocks for threads.
inline int CAFFE_GET_BLOCKS(const int N) {
  return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
}
// CUDA: grid stride looping
#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
       i < (n); \
       i += blockDim.x * gridDim.x)
/******** PReLU CUDA function ********/
// CUDA kernel for the forward pass
template <typename Ftype>
__global__ void PReLUForward(const int n, const int channels, const int dim,
                             const Ftype* slope_data,
                             const Ftype* in, Ftype* out,
                             const Ftype zero,
                             const int div_factor) {
  CUDA_KERNEL_LOOP(index, n) {
    int c = (index / dim) % channels / div_factor;
    // Compare against a typed zero so the kernel also works for half precision;
    // the original Caffe implementation only targets float and double.
    out[index] = (in[index] > zero) ? in[index] : in[index] * slope_data[c];
    // Original Caffe form:
    // out[index] = (in[index] > 0) ? in[index] : in[index] * slope_data[c];
  }
}
template <typename Ftype>
cudaError_t Forward_gpu(const int count, const int channels, const int dim,
                        const Ftype* mDeviceKernel,
                        const Ftype* bottom_data, Ftype* top_data,
                        const Ftype zero,
                        const int div_factor, const cudaStream_t stream) {
  PReLUForward<<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS, 0, stream>>>
      (count, channels, dim, mDeviceKernel, bottom_data, top_data, zero, div_factor);
  // Report launch errors (for example an invalid launch configuration) to the caller.
  cudaError_t err = cudaGetLastError();
  return err;
}
// After the implementation, the source file still needs declarations like these (at the end of the file)
// function instantiation
// https://courses.cs.washington.edu/courses/cse326/02wi/computing/c++-templates.html
template cudaError_t Forward_gpu<float>(const int count, const int channels, const int dim,
                                        const float* mDeviceKernel,
                                        const float* bottom_data, float* top_data,
                                        const float zero,
                                        const int div_factor,
                                        const cudaStream_t stream);
template cudaError_t Forward_gpu<__half>(const int count, const int channels, const int dim,
                                         const __half* mDeviceKernel,
                                         const __half* bottom_data, __half* top_data,
                                         const __half zero,
                                         const int div_factor,
                                         const cudaStream_t stream);
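The same declare / define / explicitly instantiate pattern can be tried without CUDA. Below is a minimal single-file sketch, assuming a hypothetical function template `Scale` and hypothetical file names `scale.hpp` and `scale.cpp` (none of them come from the code above); the comments mark which part would go into which file.

```cpp
#include <cassert>

// ---- would live in scale.hpp (hypothetical file name) ----
// Declaration only: callers see the signature but not the body.
template <typename T>
T Scale(const T value, const int factor);

// ---- would live in scale.cpp (hypothetical file name) ----
// Definition, visible only inside this translation unit.
template <typename T>
T Scale(const T value, const int factor) {
  return value * static_cast<T>(factor);
}

// Explicit instantiation definitions: ask the compiler to emit code for
// exactly these types so that other translation units can link against them.
template float Scale<float>(const float value, const int factor);
template double Scale<double>(const double value, const int factor);

// Small self-check that uses only the two instantiated types.
inline void CheckScale() {
  assert(Scale<float>(1.5f, 2) == 3.0f);
  assert(Scale<double>(2.5, 4) == 10.0);
}
```

In a real two-file layout, a caller that uses `Scale<int>` would compile but fail to link, because no `int` instantiation was emitted in scale.cpp. In this single-file sketch the definition is visible everywhere, so that failure does not show up.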
II. Template classes
1. Declare and implement the template class in the header
For example, declare and implement the class in Test.hpp:
template< typename T >
class Example
{
public:
    void SetValue( const T& newValue );
private:
    T m_value;
};
template< typename T >
void
Example< T >::SetValue( const T& newValue )
{
    m_value = newValue;
}
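With everything in the header, any type can be used without further steps, because the compiler instantiates the class wherever it is needed. A minimal sketch of that usage follows; `GetValue`, `DemoInt`, and `DemoString` are hypothetical additions that only exist here so the stored value can be observed.

```cpp
#include <cassert>
#include <string>

// Same shape as Example above, plus a hypothetical GetValue() accessor.
template <typename T>
class Example {
 public:
  void SetValue(const T& newValue);
  const T& GetValue() const;  // assumption: not part of the original class

 private:
  T m_value{};
};

template <typename T>
void Example<T>::SetValue(const T& newValue) {
  m_value = newValue;
}

template <typename T>
const T& Example<T>::GetValue() const {
  return m_value;
}

// Implicit instantiation for int: no explicit instantiation is required.
inline int DemoInt() {
  Example<int> e;
  e.SetValue(42);
  return e.GetValue();
}

// Implicit instantiation for std::string works the same way.
inline std::string DemoString() {
  Example<std::string> e;
  e.SetValue("hello");
  return e.GetValue();
}
```

The trade-off is that every translation unit including the header compiles the template code again, which is one reason larger projects sometimes prefer the second approach.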
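2. Declare in the header and implement in the source file, but remember to explicitly instantiate the template class; otherwise the linker reports LNK2001 (MSVC's unresolved external symbol error).
(1) Define a macro (based on the Caffe source and similar code in other frameworks)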
// Instantiate a class with float and double specifications.
#define INSTANTIATE_CLASS(classname) \
  char gInstantiationGuard##classname; \
  template class classname<float>; \
  template class classname<double>  // more supported types can be added here
(2) Example
See how template classes are implemented in the Caffe source, for instance Blob. A simpler example is used here:
// In Test.hpp:
template< typename T >
class Example
{
public:
    void SetValue( const T& newValue );
private:
    T m_value;
};
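// In Test.cpp:
#include "Test.hpp"
template< typename T >
void
Example< T >::SetValue( const T& newValue )
{
    m_value = newValue;
}
// This line is essential; without it the linker reports LNK2001.
// The macro body ends without a semicolon, so one is supplied here.
INSTANTIATE_CLASS(Example);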
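To see what the macro buys the caller, here is a single-file sketch that folds Test.hpp, Test.cpp, and a caller together. `GetValue`, `DemoFloat`, and `DemoDouble` are hypothetical additions for observing the result; the macro is the one from (1).

```cpp
#include <cassert>

// Macro from (1): explicit instantiation for float and double.
#define INSTANTIATE_CLASS(classname) \
  char gInstantiationGuard##classname; \
  template class classname<float>; \
  template class classname<double>

// ---- Test.hpp ----
template <typename T>
class Example {
 public:
  void SetValue(const T& newValue);
  T GetValue() const;  // hypothetical accessor, not in the original class

 private:
  T m_value{};
};

// ---- Test.cpp ----
template <typename T>
void Example<T>::SetValue(const T& newValue) {
  m_value = newValue;
}

template <typename T>
T Example<T>::GetValue() const {
  return m_value;
}

// Expands to: char gInstantiationGuardExample;
//             template class Example<float>;
//             template class Example<double>;
INSTANTIATE_CLASS(Example);

// ---- caller (would be another .cpp that only includes Test.hpp) ----
inline float DemoFloat() {
  Example<float> e;
  e.SetValue(1.5f);
  return e.GetValue();
}

inline double DemoDouble() {
  Example<double> e;
  e.SetValue(2.5);
  return e.GetValue();
}
```

With the files actually split, a caller using `Example<int>` would hit the unresolved-symbol error again, since only the `float` and `double` versions are emitted in Test.cpp. Supporting another type means adding one more `template class classname<...>;` line to the macro.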