虽然 TVM 支持基本的算术操作,但在很多情况下我们还需要更复杂的内置函数(例如 exp 指数函数)。
这些函数依赖于目标系统,在不同的目标平台上可能有不同的名称。下面我们来学习如何直接调用这些目标特定的外部函数,以及如何通过 TVM 的内置函数(intrinsic)API 统一各目标的调用接口。
from __future__ import absolute_import, print_function
import tvm
from tvm import te
import numpy as np
直接声明外部数学调用
# --- Directly declare an external math call ---------------------------------
# The compute rule emits a call to the CUDA device function __expf via
# call_pure_extern, so the generated kernel invokes that name verbatim.
n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute(
    A.shape,
    lambda i: tvm.tir.call_pure_extern("float32", "__expf", A[i]),
    name="B",
)
s = te.create_schedule(B.op)
num_thread = 64
bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis("threadIdx.x")
s[B].bind(bx, block_x)
s[B].bind(tx, thread_x)
f = tvm.build(s, [A, B], "cuda", name="myexp")
print(f.imported_modules[0].get_source())
extern "C" __global__ void myexp_kernel0(float* __restrict__ B, float* __restrict__ A, int n, int stride, int stride1) {
if (((int)blockIdx.x) < (n >> 6)) {
B[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride1))] = __expf(A[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride))]);
} else {
if (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) < n) {
B[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride1))] = __expf(A[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride))]);
}
}
}
统一的内置函数调用
上面的代码验证了直接外部调用可以用来调用特定于设备的函数。然而,上述方法仅适用于浮点类型的CUDA目标。理想情况下,我们希望为任何设备和任何数据类型编写相同的代码。
TVM内置函数为用户提供了一种实现这个目的的机制,这是解决这类问题的推荐的方法.
# --- Unified intrinsic call --------------------------------------------------
# te.exp is target-agnostic: TVM's lowering rules pick the right extern
# function per target, so the same schedule builds for CUDA and OpenCL.
n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute(A.shape, lambda i: te.exp(A[i]), name="B")
s = te.create_schedule(B.op)
num_thread = 64
bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis("threadIdx.x")
s[B].bind(bx, block_x)
s[B].bind(tx, thread_x)
fcuda = tvm.build(s, [A, B], "cuda", name="myexp")
print(fcuda.imported_modules[0].get_source())
extern "C" __global__ void myexp_kernel0(float* __restrict__ B, float* __restrict__ A, int n, int stride, int stride1) {
if (((int)blockIdx.x) < (n >> 6)) {
B[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride1))] = __expf(A[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride))]);
} else {
if (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) < n) {
B[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride1))] = __expf(A[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride))]);
}
}
}
# Build the very same schedule for OpenCL; there te.exp lowers to exp().
fopencl = tvm.build(s, [A, B], "opencl", name="myexp")
opencl_src = fopencl.imported_modules[0].get_source()
print(opencl_src)
__kernel void myexp_kernel0(__global float* restrict B, __global float* restrict A, int n, int stride, int stride1) {
if (((int)get_group_id(0)) < (n >> 6)) {
B[((((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride1))] = exp(A[((((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride))]);
} else {
if (((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) < n) {
B[((((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride1))] = exp(A[((((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride))]);
}
}
}
内置函数下译规则(Intrinsic Lowering Rule)
当调用 tvm.te.exp() 时,TVM 会创建一个内置函数调用表达式(Call Expr),再按下译规则(lowering rule)把这个内置调用转换为目标特定的外部调用。
TVM 也允许用户在运行时(runtime)自定义规则。下面为 exp 自定义了一个 CUDA 下译规则。
注册时启用 override 选项,用新规则覆盖 TVM 中已有的规则。
def my_cuda_math_rule(op):
    """Custom CUDA lowering rule: map a tir intrinsic call to an extern call.

    float32 calls get the single-precision "<name>f" variant, float64 calls
    get the plain "<name>" variant; any other dtype is returned unchanged so
    the default lowering still applies.
    """
    assert isinstance(op, tvm.tir.Call)
    name = op.op.name
    assert name.startswith("tir.")
    dispatch_name = name[4:]
    # C math-function suffix for each supported floating-point dtype.
    suffixes = {"float32": "f", "float64": ""}
    suffix = suffixes.get(op.dtype)
    if suffix is None:
        return op
    return tvm.tir.call_pure_extern(op.dtype, dispatch_name + suffix, op.args[0])


# Override the built-in CUDA rule for exp with our own.
tvm.target.register_intrin_rule("cuda", "exp", my_cuda_math_rule, override=True)
<tvm.runtime.packed_func.PackedFunc at 0x7f193843dd00>
# Rebuild for CUDA: the kernel now calls expf (our custom rule), not __expf.
fcuda = tvm.build(s, [A, B], "cuda", name="myexp")
cuda_src = fcuda.imported_modules[0].get_source()
print(cuda_src)
extern "C" __global__ void myexp_kernel0(float* __restrict__ B, float* __restrict__ A, int n, int stride, int stride1) {
if (((int)blockIdx.x) < (n >> 6)) {
B[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride1))] = expf(A[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride))]);
} else {
if (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) < n) {
B[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride1))] = expf(A[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride))]);
}
}
}
# OpenCL is unaffected by the CUDA-only override and still emits exp().
fopencl = tvm.build(s, [A, B], "opencl", name="myexp")
opencl_src = fopencl.imported_modules[0].get_source()
print(opencl_src)
__kernel void myexp_kernel0(__global float* restrict B, __global float* restrict A, int n, int stride, int stride1) {
if (((int)get_group_id(0)) < (n >> 6)) {
B[((((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride1))] = exp(A[((((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride))]);
} else {
if (((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) < n) {
B[((((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride1))] = exp(A[((((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride))]);
}
}
}
添加自己的内置函数
def mylog(x):
    """User-defined intrinsic: wrap *x* in a tir.mylog call of the same dtype."""
    dtype = x.dtype
    return tvm.tir.call_intrin(dtype, "tir.mylog", x)
def my_cuda_mylog_rule(op):
    """CUDA lowering rule for tir.mylog: dispatch to logf/log by dtype."""
    extern_names = {"float32": "logf", "float64": "log"}
    fname = extern_names.get(op.dtype)
    if fname is None:
        # Unsupported dtype: leave the call untouched.
        return op
    return tvm.tir.call_pure_extern(op.dtype, fname, op.args[0])
# Declare tir.mylog as a pure call (no side effects) and register the
# CUDA-specific lowering rule for it. Note: only "cuda" gets a rule here,
# which is why the later OpenCL build of mylog fails.
tvm.ir.register_op_attr("tir.mylog", "TCallEffectKind", tvm.tir.CallEffectKind.Pure)
tvm.target.register_intrin_rule("cuda", "mylog", my_cuda_mylog_rule, override=True)
<tvm.runtime.packed_func.PackedFunc at 0x7f1940674d00>
# --- Use the new intrinsic in a compute definition and build for CUDA -------
n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute(A.shape, lambda i: mylog(A[i]), name="B")
s = te.create_schedule(B.op)
num_thread = 64
bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis("threadIdx.x")
s[B].bind(bx, block_x)
s[B].bind(tx, thread_x)
fcuda = tvm.build(s, [A, B], "cuda", name="mylog")
print(fcuda.imported_modules[0].get_source())
extern "C" __global__ void mylog_kernel0(float* __restrict__ B, float* __restrict__ A, int n, int stride, int stride1) {
if (((int)blockIdx.x) < (n >> 6)) {
B[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride1))] = logf(A[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride))]);
} else {
if (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) < n) {
B[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride1))] = logf(A[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride))]);
}
}
}
# NOTE: this build fails with "Unresolved call Op(tir.mylog)" — the lowering
# rule for tir.mylog was registered only for the "cuda" target, so the
# OpenCL codegen has no rule to resolve the intrinsic (see traceback below).
fopencl = tvm.build(s, [A, B], "opencl", name="myexp")
---------------------------------------------------------------------------
TVMError Traceback (most recent call last)
<ipython-input-13-777719bed849> in <module>
----> 1 fopencl = tvm.build(s, [A, B], "opencl", name="myexp")
~/tvm/python/tvm/driver/build_module.py in build(inputs, args, target, target_host, name, binds)
414 device_modules = []
415 for tar, input_mod in target_input_mod.items():
--> 416 mod_host, mdev = _build_for_device(input_mod, tar, target_host)
417 mod_host_all.update(mod_host)
418 device_modules.append(mdev)
~/tvm/python/tvm/driver/build_module.py in _build_for_device(input_mod, target, target_host)
295 )
296
--> 297 rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None
298 return mod_host, rt_mod_dev
299
~/tvm/python/tvm/target/codegen.py in build_module(mod, target)
37 """
38 target = Target(target) if isinstance(target, str) else target
---> 39 return _ffi_api.Build(mod, target)
40
41
~/tvm/python/tvm/_ffi/_ctypes/packed_func.py in __call__(self, *args)
235 != 0
236 ):
--> 237 raise get_last_ffi_error()
238 _ = temp_args
239 _ = args
TVMError: Traceback (most recent call last):
[bt] (8) /home/liu/tvm/build/libtvm.so(tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)+0x265) [0x7f1915b77b85]
[bt] (7) /home/liu/tvm/build/libtvm.so(tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::IfThenElseNode const*)+0xb9) [0x7f191628cf29]
[bt] (6) /home/liu/tvm/build/libtvm.so(tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)+0x265) [0x7f1915b77b85]
[bt] (5) /home/liu/tvm/build/libtvm.so(tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::StoreNode const*)+0x4ac) [0x7f191628e47c]
[bt] (4) /home/liu/tvm/build/libtvm.so(tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)+0x19f) [0x7f191629a9bf]
[bt] (3) /home/liu/tvm/build/libtvm.so(tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)+0x95) [0x7f191628a2f5]
[bt] (2) /home/liu/tvm/build/libtvm.so(tvm::codegen::CodeGenOpenCL::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)+0x5b) [0x7f19162b89eb]
[bt] (1) /home/liu/tvm/build/libtvm.so(tvm::codegen::CodeGenC::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)+0x6bf) [0x7f1916290e0f]
[bt] (0) /home/liu/tvm/build/libtvm.so(+0xd28a18) [0x7f1916287a18]
File "/home/liu/tvm/src/target/source/codegen_c.cc", line 652
TVMError: Unresolved call Op(tir.mylog)