为什么
不同硬件后端厂商提供不同的库,使得用户需要学习很多不同的编程范式。比如MKLDNN,cuDNN,TensorRT,因此需要一个统一的编程接口去调用这些库,从而让TVM可以正确依赖于这些后端库生成对应的代码。
是什么
- let all users and hardware backend providers stand on the same page
- provide a feasible solution to allow specialized hardware or library to only support widely used operators with extremely high performance, but fallback unsupported operators to general devices like CPU/GPU.
TVM Codegen是一个面向硬件的统一化的编程接口。
原理
使用字符串拼接生成代码。
Codegen继承自ExprVisitor和CodegenBase。前者提供遍历计算图的能力并生成子图函数,后者定义了Codegen的生成规则。生成代码时,从VisitExpr_里面拿节点信息(比如CallNode)。
例子
现需要生成这样的算子以及函数调用:
#define CSOURCE_BINARY_OP_1D(p_ID_, p_OP_, p_DIM1_) \
extern "C" void p_ID_(float* a, float* b, float* out) { \
for (int64_t i = 0; i < p_DIM1_; ++i) { \
out[i] = a[i] p_OP_ b[i]; \
} \
}
#define CSOURCE_BINARY_OP_2D(p_ID_, p_OP_, p_DIM1_, p_DIM2_) \
extern "C" void p_ID_(float* a, float* b, float* out) { \
for (int64_t i = 0; i < p_DIM1_; ++i) { \
for (int64_t j = 0; j < p_DIM2_; ++j) { \
int64_t k = i * p_DIM2_ + j; \
out[k] = a[k] p_OP_ b[k]; \
} \
} \
}
// Note 1
GCC_BINARY_OP_2D(gcc_0_0, *, 10, 10);
GCC_BINARY_OP_2D(gcc_0_1, -, 10, 10);
GCC_BINARY_OP_2D(gcc_0_2, +, 10, 10);
// Note 2
extern "C" void gcc_0_(float* gcc_input0, float* gcc_input1,
float* gcc_input2, float* gcc_input3, float* out) {
float* buf_0 = (float*)malloc(4 * 100);
float* buf_1 = (float*)malloc(4 * 100);
gcc_0_2(gcc_input0, gcc_input1, buf_0);
gcc_0_1(buf_0, gcc_input2, buf_1);
gcc_0_0(buf_1, gcc_input3, out);
free(buf_0);
free(buf_1);
}
// Note 3
extern "C" int gcc_0_wrapper(DLTensor* arg0, DLTensor* arg1, DLTensor* arg2,
DLTensor* arg3, DLTensor* out) {
gcc_0_(static_cast<float*>(arg0->data), static_cast<float*>(arg1->data),
static_cast<float*>(arg2->data), static_cast<float*>(arg3->data),
static_cast<float*>(out->data));
return 0;
}
TVM_DLL_EXPORT_TYPED_FUNC(gcc_0, gcc_0_wrapper);
要做的是通过字符串拼接实现代码生成。
需要生成:
- 函数声明
- 函数调用
- temp buffer(临时储存计算结果,之后free掉)
- output buffer(保存函数调用的结果)
- 调用所有函数并传入buffer,得到输出。
(以上由CodegenC类实现) - 函数定义(涉及到具体的函数计算逻辑)
(单独由CSourceCodegen类实现)
生成函数声明
GCC_BINARY_OP_2D(gcc_0_0, *, 10, 10);
有如下代码片段:src/relay/backend/contrib/codegen_c/codegen.cc
std::ostringstream macro_stream;
std::ostringstream decl_stream;
std::ostringstream buf_stream;
macro_stream << "CSOURCE_BINARY_OP_" << call->args.size() << "D(" << func_name << ", ";
if (IsOp(call, "add")) {
macro_stream << "+";
} else if (IsOp(call, "subtract")) {
macro_stream << "-";
} else if (IsOp(call, "multiply")) {
macro_stream << "*";
} else {
LOG(FATAL) << "Unrecognized op";
}
...
func_decl_.push_back(macro_stream.str());
最后将输出流变成str储存到func_decl_容器中。对每个CallNode都这样操作,最后得到所有所需的函数声明。函数的定义则由CSourceCodegen实现。