1. kernel装饰器
之后我们看一下@ti.kernel
装饰器
def _kernel_impl(_func, level_of_class_stackframe, verbose=False):
is_classkernel = _inside_class(level_of_class_stackframe + 1)
if verbose:
print(f'kernel={_func.__name__} is_classkernel={is_classkernel}')
primal = Kernel(_func,
autodiff_mode=AutodiffMode.NONE,
_classkernel=is_classkernel)
adjoint = Kernel(_func,
autodiff_mode=AutodiffMode.REVERSE,
_classkernel=is_classkernel)
primal.grad = adjoint
if is_classkernel:
@functools.wraps(_func)
def wrapped(*args, **kwargs):
clsobj = type(args[0])
assert not hasattr(clsobj, '_data_oriented')
raise TaichiSyntaxError(
f'Please decorate class {clsobj.__name__} with @ti.data_oriented'
)
else:
@functools.wraps(_func)
def wrapped(*args, **kwargs):
try:
return primal(*args, **kwargs)
except (TaichiCompilationError, TaichiRuntimeError) as e:
raise type(e)('\n' + str(e)) from None
wrapped.grad = adjoint
wrapped._is_wrapped_kernel = True
wrapped._is_classkernel = is_classkernel
wrapped._primal = primal
wrapped._adjoint = adjoint
return wrapped
def kernel(fn):
return _kernel_impl(fn, level_of_class_stackframe=3)
kernel
是一个装饰器函数,其内部调用了_kernel_impl
作为具体实现。这里首先判断了当前的kernel是否在一个类中,是的,taichi允许你在一个类中定义kernel,这给了开发者极大的便利,同时taichi用了非常巧妙的方式来处理在类中的Kernel,这里我们暂时不讨论,我先关注于在类外面的kernel函数。而_inside_class
方法具体是如何判断是否在类中的,则用了我们的老朋友,inspect
模块来获取解释器堆栈上的语句信息,通过正则匹配检测class关键字实现的。
接下来实例化了两个Kernel类,这个Kernel类是taichi程序的编译核心,也是执行入口,具体编译过程在Kernel的__call__
魔法函数中。
之后是一个内部闭包wrapped,functools.wrap装饰器的作用是不改变调用对象,单纯改变所装饰对象的一些属性,也就是这里最后返回的是wrapped函数,但是属性则依然是传入的play函数,所以我们打印play函数的__name__
和__doc__
属性就会发现没有发生改变,但是具体调用逻辑则是调用Kernel的__call__
函数。
那我们接下来就很好看一下Kernel这个类:
class Kernel:
counter = 0
def __init__(self, _func, autodiff_mode, _classkernel=False):
self.func = _func
self.kernel_counter = Kernel.counter
Kernel.counter += 1
assert autodiff_mode in (AutodiffMode.NONE, AutodiffMode.VALIDATION,
AutodiffMode.FORWARD, AutodiffMode.REVERSE)
self.autodiff_mode = autodiff_mode
self.grad = None
self.arguments = []
self.return_type = None
self.classkernel = _classkernel
self.extract_arguments()
self.template_slot_locations = []
for i, arg in enumerate(self.arguments):
if isinstance(arg.annotation, template):
self.template_slot_locations.append(i)
self.mapper = TaichiCallableTemplateMapper(
self.arguments, self.template_slot_locations)
impl.get_runtime().kernels.append(self)
self.reset()
self.kernel_cpp = None
# TODO[#5114]: get rid of compiled_functions and use compiled_kernels instead.
# Main motivation is that compiled_kernels can be potentially serialized in the AOT scenario.
self.compiled_kernels = {}
self.has_print = False
在初始化Kernel时,把所装饰的函数作为了当前func属性,获取当前kernel的计数器,之后静态成员变量counter+1。判断是否开启auto_diff,是否有自动微分等。之后通过extract_argument来获取函数形参参数:
def extract_arguments(self):
sig = inspect.signature(self.func)
if sig.return_annotation not in (inspect._empty, None):
self.return_type = sig.return_annotation
params = sig.parameters
arg_names = params.keys()
for i, arg_name in enumerate(arg_names):
param = params[arg_name]
if param.kind == inspect.Parameter.VAR_KEYWORD:
raise TaichiSyntaxError(
'Taichi kernels do not support variable keyword parameters (i.e., **kwargs)'
)
......
annotation = param.annotation
if param.annotation is inspect.Parameter.empty:
if i == 0 and self.classkernel: # The |self| parameter
annotation = template()
else:
raise TaichiSyntaxError(
'Taichi kernels parameters must be type annotated')
......
self.arguments.append(
KernelArgument(annotation, param.name, param.default))
获取的方式当然又是我们的老熟人inspect模块,通过signature获取到func上的形参和返回值类型属性,之后是一大段参数类型的异常判断,在这里我们看到关键字参数,带初始值的参数都是不支持的。之后将这些参数加入到Kernel类的成员变量arguments中,Taichi在这里对参数又进行了一层包装,不过KernelArgument就是一个非常简单的data类了:
class KernelArgument:
def __init__(self, _annotation, _name, _default=inspect.Parameter.empty):
self.annotation = _annotation
self.name = _name
self.default = _default
让我们回到Kernel的构造函数,获取到函数参数后对这些形参进行了遍历,判断是否是Template类型参数,这里的Template是一个用Python定义的一个Data类,用于taichi的元编程使用,taichi的模板编程实现的方式我们也暂且不表。再往下是初始化了一个TaichiCallbleTemplateMapper
这个map存储了参数了信息为后续调用时使用。随后将现在的kernel加入到我们之前提到的pytaichi全局变量中,那是一个Program类,存储了全部的kernel。再之后就是调用了reset函数:
def reset(self):
self.runtime = impl.get_runtime()
reset函数非常简单,把初始化当前runtime成员变量的方法提取出来以供复用,这里runtime所赋值的依然是pytaichi。到此Kernel的初始化工作基本完成了。
Taichi所采用的对Python实施静态编译的策略与我们之前探讨的pytorch和qcor有所不同,他并不是在装饰器加载时就完成了jit编译,而是在第一次调用时才进行编译,下面我们就探究一下Kernel的__call__
函数:
2. kernel的运行入口
@_shell_pop_print
def __call__(self, *args, **kwargs):
args = _process_args(self, args, kwargs)
if self.runtime.fwd_mode_manager and not self.runtime.grad_replaced:
self.runtime.fwd_mode_manager.insert(self)
if self.autodiff_mode in (
AutodiffMode.NONE, AutodiffMode.VALIDATION
) and self.runtime.target_tape and not self.runtime.grad_replaced:
self.runtime.target_tape.insert(self, args)
if self.autodiff_mode != AutodiffMode.NONE and impl.current_cfg(
).opt_level == 0:
_logging.warn(
"""opt_level = 1 is enforced to enable gradient computation."""
)
impl.current_cfg().opt_level = 1
key = self.ensure_compiled(*args)
return self.runtime.compiled_functions[key](*args)
这段程序的代码量并不多,首先_shell_pop_print
装饰器在绝大多数情况下冰不起作用,不会改变调用主体,只有在开启pybuf时会有一个额外处理,但依然会调用__call__
函数,并将其返回值作为新函数返回值,只是会进行一步额外信息打印操作。
之后对实参进行处理,这里的_process_arg
主要做了参数校验的工作,判断了实参和形参的个数是否统一,并将*args
解构赋值到一个list中返回。
之后判断了是否启用了自动微分,在本例中未使用自动微分,这一步也不用考虑。之后就进入到了我们的编译环节,ensure_compiled
函数代码如下:
def ensure_compiled(self, *args):
instance_id, arg_features = self.mapper.lookup(args)
key = (self.func, instance_id, self.autodiff_mode)
self.materialize(key=key, args=args, arg_features=arg_features)
return key
这里调用TaichiCallableTemplateMapper
的loopup
功能获得的instance_id和arg_features分别为0和两个'#'
字符串组成的list。
def materialize(self, key=None, args=None, arg_features=None):
if key is None:
key = (self.func, 0, self.autodiff_mode)
self.runtime.materialize()
if key in self.runtime.compiled_functions:
return
grad_suffix = ""
......
kernel_name = f"{self.func.__name__}_c{self.kernel_counter}_{key[1]}{grad_suffix}"
_logging.trace(f"Compiling kernel {kernel_name}...")
tree, ctx = _get_tree_and_ctx(
self,
args=args,
excluded_parameters=self.template_slot_locations,
arg_features=arg_features)
if self.autodiff_mode != AutodiffMode.NONE:
KernelSimplicityASTChecker(self.func).visit(tree)
def taichi_ast_generator(kernel_cxx):
......
taichi_kernel = impl.get_runtime().prog.create_kernel(
taichi_ast_generator, kernel_name, self.autodiff_mode)
self.kernel_cpp = taichi_kernel
assert key not in self.runtime.compiled_functions
self.runtime.compiled_functions[key] = self.get_function_body(
taichi_kernel)
self.compiled_kernels[key] = taichi_kernel
上面是Kernel中的materialize
函数,可以看到这里又调用了一次ProgramImpl的materialize,这个在之前Init的时候介绍过,主要是Jit相关的设置。之后进行判断是否进行过了编译,jit编译后的二进制代码会存放在内存中的,所以只需要第一次执行时进行编译即可,所以这里如果发现该kernel已经经过编译了,则这一步可以跳过。之后我们可以看到JIT函数的规则,Python本身的函数名+当前kernel序号构成,后面是一些附加信息,这一串规则可以巧妙的避免命名重复导致错误,同时比使用Uuid等方式随机产生的一大段随机字符串在长度方面小很多。我们此处的名称为play_c76_0
这个也是我们之后实际调用JIT的函数名称。接下来这一段是用来获取Python AST树的:
def _get_tree_and_ctx(self,
excluded_parameters=(),
is_kernel=True,
arg_features=None,
args=None,
ast_builder=None,
is_real_function=False):
file = getsourcefile(self.func)
src, start_lineno = getsourcelines(self.func)
src = [textwrap.fill(line, tabsize=4, width=9999) for line in src]
tree = ast.parse(textwrap.dedent("\n".join(src)))
func_body = tree.body[0]
func_body.decorator_list = []
global_vars = _get_global_vars(self.func)
for i, arg in enumerate(func_body.args.args):
anno = arg.annotation
if isinstance(anno, ast.Name):
global_vars[anno.id] = self.arguments[i].annotation
if isinstance(func_body.returns, ast.Name):
global_vars[func_body.returns.id] = self.return_type
if is_kernel or is_real_function:
# inject template parameters into globals
for i in self.template_slot_locations:
template_var_name = self.arguments[i].name
global_vars[template_var_name] = args[i]
return tree, ASTTransformerContext(excluded_parameters=excluded_parameters,
is_kernel=is_kernel,
func=self,
arg_features=arg_features,
global_vars=global_vars,
argument_data=args,
src=src,
start_lineno=start_lineno,
file=file,
ast_builder=ast_builder,
is_real_function=is_real_function)
又是我们的老朋友inspect
模块,其实整体逻辑还是比较简单的,通过inspect模块获取到了所执行的文件路径和所执行的函数源码。之后通过_get_global_vars
获取了全局变量,此处对闭包做了特殊处理。随后将参数和返回值类型加入到全局变量集合中。之后返回了ast树和taichi自己定义的用于做转换的Python类ASTTransformerContext
(代码就不PO了,其实基本上算是一个Data类了,在初始化时没有做特殊操作)。
之后我们回到materialize
函数中,如果此时开启了auto_diff,则会对AST树进行解析,做一些语义检测,这里采取的是访问器模式,之后会详细开一期介绍taichi 在启用不同模式auto_diff情况下的生命周期。
随后调用taichi的c++库来创建C++ Kernel对象并加入到当前pytaichi(上文提到的全局PyTaichi对象)的program中去。我们来看一下create_kernel的c++实现吧,这里其实在c++ pybind文件中,是一个匿名函数
.def(
"create_kernel",
[](Program *program, const std::function<void(Kernel *)> &body,
const std::string &name, AutodiffMode autodiff_mode) -> Kernel * {
py::gil_scoped_release release;
return &program->kernel(body, name, autodiff_mode);
},
Kernel &Program::kernel(const std::function<void(Kernel *)> &body,
const std::string &name = "",
AutodiffMode autodiff_mode = AutodiffMode::kNone) {
// Expr::set_allow_store(true);
auto func = std::make_unique<Kernel>(*this, body, name, autodiff_mode);
// Expr::set_allow_store(false);
kernels.emplace_back(std::move(func));
return *kernels.back();
}
注意这里pybind中的py::gil_scoped_release release;
这句话,这个是为了多线程情况下对GIL锁的处理,Python C API 规定全局解释器锁 (GIL) 必须始终由当前线程持有才能安全访问 Python 对象。因此,当 Python 通过 pybind11 调用 C++ 时,必须持有 GIL,而 pybind11 永远不会隐式释放 GIL,所以我们就需要手动释放和获取锁。
pybind11 需要确保它正在调用 Python 代码时保持 GIL。如上,这里有趣的一点是我们这里传入的其实是一个Python函数给std::function
类型参数作为一个回调,当c++调用该回调函数时需要确保持有GIL,这段会很复杂,我会在后续c++部分的生命周期的时候话一段时间来详解这一块。
回到Python部分,接下来要处理的就是kernel装饰器所装饰的play函数本身了,get_function_body
具体内容简单来说就是非常爽快的返回了一个闭包。这个闭包我们之后会用到,这里就直接将他加入到了pytaichi的compiled_functions
字典中。之后将C++ Kernekl对象加入到当前Python Kernel类中的compiled_kernels
字典中。
在初始化Kernel阶段,会执行传入的回调函数,我们来具体看一下那个回调函数
def taichi_ast_generator(kernel_cxx):
if self.runtime.inside_kernel:
raise TaichiSyntaxError(......)
self.runtime.inside_kernel = True
self.runtime.current_kernel = self
try:
ctx.ast_builder = kernel_cxx.ast_builder()
transform_tree(tree, ctx)
if not ctx.is_real_function:
if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
raise TaichiSyntaxError(
"Kernel has a return type but does not have a return statement"
)
finally:
self.runtime.inside_kernel = False
self.runtime.current_kernel = None
其实这个回调函数传入的kernel_cxx就是刚刚创建的c++的kernel类,获取的ast_build就FrotentContext默认的ASTBuilder, FrontentContext在创建Kernel示例时会创建。
之后运行transform_tree函数,这一步其实这个回调最关键的时刻:
这一步稍显复杂,这里通过递归来完成了build
至此ensure_compiled
功能就走完了,回到最初的__call__
函数,最后操作即是从pytaichi的compiled_functions
取出之前加入的闭包,并将闭包的结果当做kernel运行的结果返回。到此具体的编译过程正式展开(上面存在一步编译了,利用ast模块对play源码进行了parser,获取到了AST树,这里其实有一定优化空间,Python在对py文件整体编译时保存了AST树,可惜的是无法从Python语言层面去获取到这个内部对象,不过可以拓展cpython来实现这一功能,但是稍显鸡肋,小量代码的parser速度是非常快的,这一部分提升可以忽略不计了)。