一 杂记
1.1 基础代码
// Generic addition: returns a + b for two operands of the same type T.
// T must provide a binary operator+ whose result converts to T.
template <class T>
auto add(T a, T b) -> T
{
    return a + b;
}

// Explicit instantiation: requests that the add<int> specialization be
// generated here, independent of any call site.
template int add<int>(int, int);
// Demonstrates the three ways the add template gets its argument T.
int main()
{
    // Template argument T is spelled out explicitly.
    add<int>(1, 2);

    // T is deduced as int from the call arguments.
    add(1, 2);

    // T is deduced as float, and add<float> is implicitly instantiated.
    add(1.0f, 2.0f);

    return 0;
}
模板本质上是代码生成器
1.2 参数中有函数
// Element-wise unary kernel skeleton, parameterized on the element type and on
// the compute function itself.
//
// Template parameters:
//   T      - element type of the input/output buffers.
//   OpFunc - compile-time function pointer with the signature
//            void(T *, T *, T *, T *, int, int, float). It is invoked below as
//            OpFunc(nram_x, nram_x_half, nram_aux_a, nram_aux_b,
//                   num_deal, num_deal, coef).
//
// Parameters:
//   x, y          - input and output base pointers (y is only offset, not yet
//                   written, in the visible part of this excerpt).
//   nram_buffer   - scratch buffer that is carved into four T* regions below.
//   num_total     - total number of elements across all tasks.
//   offset_x_half, offset_aux_a, offset_aux_b
//                 - offsets into nram_buffer, counted in elements of T.
//   num_deal      - number of elements handled per chunk.
//   num_pong      - element offset between the first and second staging slots.
//   coef          - scalar forwarded unchanged to OpFunc.
//
// NOTE(review): this excerpt is truncated - the function body is not closed
// and the remaining pipeline stages / remainder handling are not shown.
template <typename T, void (*OpFunc)(T *, T *, T *, T *, int, int, float)>
__mlu_func__ void block3Unary(T *x, T *y, char *nram_buffer, int32_t num_total,
int32_t offset_x_half, int32_t offset_aux_a,
int32_t offset_aux_b, int32_t num_deal,
int32_t num_pong, float coef) {
// Nothing to do when __is_mpu() reports true.
// NOTE(review): presumably this is the non-compute (memory) core - confirm.
if (__is_mpu()) {
return;
}
// Split the work evenly across taskDim tasks; the start addresses are computed
// from the even share BEFORE the remainder is added, so only the last task
// (taskId == taskDim - 1) grows to absorb the leftover elements.
int32_t num_per_core = num_total / taskDim;
int32_t num_rem = num_total % taskDim;
T *addr_x = (T *)x + taskId * num_per_core;
T *addr_y = (T *)y + taskId * num_per_core;
if (num_rem > 0 && taskId == taskDim - 1) {
num_per_core = num_per_core + num_rem;
}
// repeat: number of full num_deal-sized chunks for this task.
// rem: leftover elements after the full chunks.
// align_rem: rem passed through CEIL_ALIGN with UNARY_ALIGN_NUM
// (presumably rounds up to that alignment; unused in the visible excerpt).
int32_t repeat = num_per_core / num_deal;
int32_t rem = num_per_core % num_deal;
int32_t align_rem = CEIL_ALIGN(rem, UNARY_ALIGN_NUM);
// Carve the scratch buffer into four regions; the offsets are applied after
// the cast to T*, so they are in units of T, not bytes.
T *nram_x = (T *)nram_buffer;
T *nram_x_half = (T *)nram_buffer + offset_x_half;
T *nram_aux_a = (T *)nram_buffer + offset_aux_a;
T *nram_aux_b = (T *)nram_buffer + offset_aux_b;
// Size in bytes of one full chunk.
int32_t span_handle_size = num_deal * sizeof(T);
// 3 level pipeline.
// Stage 1: start the asynchronous copy of the first chunk into nram_x_half
// and wait on __sync() before continuing.
if (repeat > 0) {
__memcpy_async(nram_x_half, addr_x, span_handle_size, GDRAM2NRAM);
__sync();
}
// Stage 2: issue the asynchronous copy of the second chunk into the slot at
// nram_x_half + num_pong, then run OpFunc on the first chunk before the
// __sync().
// NOTE(review): presumably the copy overlaps with the compute - confirm.
if (repeat > 1) {
__memcpy_async(nram_x_half + num_pong, addr_x + num_deal, span_handle_size,
GDRAM2NRAM);
OpFunc(nram_x, nram_x_half, nram_aux_a, nram_aux_b, num_deal, num_deal,
coef);
__sync();
}
在使用该模板的时候,除了需要传入类型参数 T,还需要传入一个签名与 OpFunc 匹配的函数(它是函数指针形式的非类型模板参数,传入的是函数本身,而不是函数的类型)