(三)Python部分生命周期之Kernel

文章详细介绍了Taichi编程框架中@kernel装饰器的工作原理,包括如何判断kernel是否在类中定义,以及Kernel类的初始化和编译过程。在初始化Kernel时,获取函数参数并进行类型检查,然后在第一次调用时进行编译。编译涉及AST树的生成和C++内核的创建。此外,文章还提到了自动微分模式和多线程中的GIL管理。
摘要由CSDN通过智能技术生成

1. kernel装饰器

之后我们看一下@ti.kernel装饰器

def _kernel_impl(_func, level_of_class_stackframe, verbose=False):
    is_classkernel = _inside_class(level_of_class_stackframe + 1)
    if verbose:
        print(f'kernel={_func.__name__} is_classkernel={is_classkernel}')
    primal = Kernel(_func,
                    autodiff_mode=AutodiffMode.NONE,
                    _classkernel=is_classkernel)
    adjoint = Kernel(_func,
                     autodiff_mode=AutodiffMode.REVERSE,
                     _classkernel=is_classkernel)
    primal.grad = adjoint
    if is_classkernel:
        @functools.wraps(_func)
        def wrapped(*args, **kwargs):
            clsobj = type(args[0])
            assert not hasattr(clsobj, '_data_oriented')
            raise TaichiSyntaxError(
                f'Please decorate class {clsobj.__name__} with @ti.data_oriented'
            )
    else:
        @functools.wraps(_func)
        def wrapped(*args, **kwargs):
            try:
                return primal(*args, **kwargs)
            except (TaichiCompilationError, TaichiRuntimeError) as e:
                raise type(e)('\n' + str(e)) from None
        wrapped.grad = adjoint
    wrapped._is_wrapped_kernel = True
    wrapped._is_classkernel = is_classkernel
    wrapped._primal = primal
    wrapped._adjoint = adjoint
    return wrapped

def kernel(fn):
    return _kernel_impl(fn, level_of_class_stackframe=3)

kernel是一个装饰器函数,其内部调用了_kernel_impl作为具体实现。这里首先判断了当前的kernel是否在一个类中,是的,taichi允许你在一个类中定义kernel,这给了开发者极大的便利,同时taichi用了非常巧妙的方式来处理在类中的Kernel,这里我们暂时不讨论,我先关注于在类外面的kernel函数。而_inside_class方法具体是如何判断是否在类中的,则用了我们的老朋友,inspect模块来获取解释器堆栈上的语句信息,通过正则匹配检测class关键字实现的。

接下来实例化了两个Kernel类,这个Kernel类是taichi程序的编译核心,也是执行入口,具体编译过程在Kernel的__call__魔法函数中。

之后是一个内部闭包wrapped,functools.wrap装饰器的作用是不改变调用对象,单纯改变所装饰对象的一些属性,也就是这里最后返回的是wrapped函数,但是属性则依然是传入的play函数,所以我们打印play函数的__name____doc__属性就会发现没有发生改变,但是具体调用逻辑则是调用Kernel的__call__函数。

那我们接下来就很好看一下Kernel这个类:

class Kernel:
    counter = 0

    def __init__(self, _func, autodiff_mode, _classkernel=False):
        self.func = _func
        self.kernel_counter = Kernel.counter
        Kernel.counter += 1
        assert autodiff_mode in (AutodiffMode.NONE, AutodiffMode.VALIDATION,
                                 AutodiffMode.FORWARD, AutodiffMode.REVERSE)
        self.autodiff_mode = autodiff_mode
        self.grad = None
        self.arguments = []
        self.return_type = None
        self.classkernel = _classkernel
        self.extract_arguments()
        self.template_slot_locations = []
        for i, arg in enumerate(self.arguments):
            if isinstance(arg.annotation, template):
                self.template_slot_locations.append(i)
        self.mapper = TaichiCallableTemplateMapper(
            self.arguments, self.template_slot_locations)
        impl.get_runtime().kernels.append(self)
        self.reset()
        self.kernel_cpp = None
        # TODO[#5114]: get rid of compiled_functions and use compiled_kernels instead.
        # Main motivation is that compiled_kernels can be potentially serialized in the AOT scenario.
        self.compiled_kernels = {}
        self.has_print = False

在初始化Kernel时,把所装饰的函数作为了当前func属性,获取当前kernel的计数器,之后静态成员变量counter+1。判断是否开启auto_diff,是否有自动微分等。之后通过extract_argument来获取函数形参参数:

def extract_arguments(self):
    sig = inspect.signature(self.func)
    if sig.return_annotation not in (inspect._empty, None):
        self.return_type = sig.return_annotation
    params = sig.parameters
    arg_names = params.keys()
    for i, arg_name in enumerate(arg_names):
        param = params[arg_name]
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            raise TaichiSyntaxError(
                'Taichi kernels do not support variable keyword parameters (i.e., **kwargs)'
            )
		......
        annotation = param.annotation
        if param.annotation is inspect.Parameter.empty:
            if i == 0 and self.classkernel:  # The |self| parameter
                annotation = template()
            else:
                raise TaichiSyntaxError(
                    'Taichi kernels parameters must be type annotated')
		......
        self.arguments.append(
            KernelArgument(annotation, param.name, param.default))

获取的方式当然又是我们的老熟人inspect模块,通过signature获取到func上的形参和返回值类型属性,之后是一大段参数类型的异常判断,在这里我们看到关键字参数,带初始值的参数都是不支持的。之后将这些参数加入到Kernel类的成员变量arguments中,Taichi在这里对参数又进行了一层包装,不过KernelArgument就是一个非常简单的data类了:

class KernelArgument:
    def __init__(self, _annotation, _name, _default=inspect.Parameter.empty):
        self.annotation = _annotation
        self.name = _name
        self.default = _default

让我们回到Kernel的构造函数,获取到函数参数后对这些形参进行了遍历,判断是否是Template类型参数,这里的Template是一个用Python定义的一个Data类,用于taichi的元编程使用,taichi的模板编程实现的方式我们也暂且不表。再往下是初始化了一个TaichiCallbleTemplateMapper这个map存储了参数了信息为后续调用时使用。随后将现在的kernel加入到我们之前提到的pytaichi全局变量中,那是一个Program类,存储了全部的kernel。再之后就是调用了reset函数:

def reset(self):
    self.runtime = impl.get_runtime()

reset函数非常简单,把初始化当前runtime成员变量的方法提取出来以供复用,这里runtime所赋值的依然是pytaichi。到此Kernel的初始化工作基本完成了。

Taichi所采用的对Python实施静态编译的策略与我们之前探讨的pytorch和qcor有所不同,他并不是在装饰器加载时就完成了jit编译,而是在第一次调用时才进行编译,下面我们就探究一下Kernel的__call__函数:

2. kernel的运行入口

@_shell_pop_print
def __call__(self, *args, **kwargs):
    args = _process_args(self, args, kwargs)
    if self.runtime.fwd_mode_manager and not self.runtime.grad_replaced:
        self.runtime.fwd_mode_manager.insert(self)
    if self.autodiff_mode in (
            AutodiffMode.NONE, AutodiffMode.VALIDATION
    ) and self.runtime.target_tape and not self.runtime.grad_replaced:
        self.runtime.target_tape.insert(self, args)
    if self.autodiff_mode != AutodiffMode.NONE and impl.current_cfg(
    ).opt_level == 0:
        _logging.warn(
            """opt_level = 1 is enforced to enable gradient computation."""
        )
        impl.current_cfg().opt_level = 1
    key = self.ensure_compiled(*args)
    return self.runtime.compiled_functions[key](*args)

这段程序的代码量并不多,首先_shell_pop_print装饰器在绝大多数情况下冰不起作用,不会改变调用主体,只有在开启pybuf时会有一个额外处理,但依然会调用__call__函数,并将其返回值作为新函数返回值,只是会进行一步额外信息打印操作。

之后对实参进行处理,这里的_process_arg主要做了参数校验的工作,判断了实参和形参的个数是否统一,并将*args解构赋值到一个list中返回。

之后判断了是否启用了自动微分,在本例中未使用自动微分,这一步也不用考虑。之后就进入到了我们的编译环节,ensure_compiled函数代码如下:

def ensure_compiled(self, *args):
    instance_id, arg_features = self.mapper.lookup(args)
    key = (self.func, instance_id, self.autodiff_mode)
    self.materialize(key=key, args=args, arg_features=arg_features)
    return key

这里调用TaichiCallableTemplateMapperloopup功能获得的instance_id和arg_features分别为0和两个'#'字符串组成的list。

def materialize(self, key=None, args=None, arg_features=None):
    if key is None:
        key = (self.func, 0, self.autodiff_mode)
    self.runtime.materialize()
    if key in self.runtime.compiled_functions:
        return
    grad_suffix = ""
		......
    kernel_name = f"{self.func.__name__}_c{self.kernel_counter}_{key[1]}{grad_suffix}"
    _logging.trace(f"Compiling kernel {kernel_name}...")

    tree, ctx = _get_tree_and_ctx(
        self,
        args=args,
        excluded_parameters=self.template_slot_locations,
        arg_features=arg_features)

    if self.autodiff_mode != AutodiffMode.NONE:
        KernelSimplicityASTChecker(self.func).visit(tree)
    def taichi_ast_generator(kernel_cxx):
		......

    taichi_kernel = impl.get_runtime().prog.create_kernel(
        taichi_ast_generator, kernel_name, self.autodiff_mode)

    self.kernel_cpp = taichi_kernel

    assert key not in self.runtime.compiled_functions
    self.runtime.compiled_functions[key] = self.get_function_body(
        taichi_kernel)
    self.compiled_kernels[key] = taichi_kernel

上面是Kernel中的materialize函数,可以看到这里又调用了一次ProgramImpl的materialize,这个在之前Init的时候介绍过,主要是Jit相关的设置。之后进行判断是否进行过了编译,jit编译后的二进制代码会存放在内存中的,所以只需要第一次执行时进行编译即可,所以这里如果发现该kernel已经经过编译了,则这一步可以跳过。之后我们可以看到JIT函数的规则,Python本身的函数名+当前kernel序号构成,后面是一些附加信息,这一串规则可以巧妙的避免命名重复导致错误,同时比使用Uuid等方式随机产生的一大段随机字符串在长度方面小很多。我们此处的名称为play_c76_0这个也是我们之后实际调用JIT的函数名称。接下来这一段是用来获取Python AST树的:

def _get_tree_and_ctx(self,
                      excluded_parameters=(),
                      is_kernel=True,
                      arg_features=None,
                      args=None,
                      ast_builder=None,
                      is_real_function=False):
    file = getsourcefile(self.func)
    src, start_lineno = getsourcelines(self.func)
    src = [textwrap.fill(line, tabsize=4, width=9999) for line in src]
    tree = ast.parse(textwrap.dedent("\n".join(src)))

    func_body = tree.body[0]
    func_body.decorator_list = []

    global_vars = _get_global_vars(self.func)
    for i, arg in enumerate(func_body.args.args):
        anno = arg.annotation
        if isinstance(anno, ast.Name):
            global_vars[anno.id] = self.arguments[i].annotation

    if isinstance(func_body.returns, ast.Name):
        global_vars[func_body.returns.id] = self.return_type

    if is_kernel or is_real_function:
        # inject template parameters into globals
        for i in self.template_slot_locations:
            template_var_name = self.arguments[i].name
            global_vars[template_var_name] = args[i]

    return tree, ASTTransformerContext(excluded_parameters=excluded_parameters,
                                       is_kernel=is_kernel,
                                       func=self,
                                       arg_features=arg_features,
                                       global_vars=global_vars,
                                       argument_data=args,
                                       src=src,
                                       start_lineno=start_lineno,
                                       file=file,
                                       ast_builder=ast_builder,
                                       is_real_function=is_real_function)

又是我们的老朋友inspect模块,其实整体逻辑还是比较简单的,通过inspect模块获取到了所执行的文件路径和所执行的函数源码。之后通过_get_global_vars获取了全局变量,此处对闭包做了特殊处理。随后将参数和返回值类型加入到全局变量集合中。之后返回了ast树和taichi自己定义的用于做转换的Python类ASTTransformerContext(代码就不PO了,其实基本上算是一个Data类了,在初始化时没有做特殊操作)。

之后我们回到materialize函数中,如果此时开启了auto_diff,则会对AST树进行解析,做一些语义检测,这里采取的是访问器模式,之后会详细开一期介绍taichi 在启用不同模式auto_diff情况下的生命周期。

随后调用taichi的c++库来创建C++ Kernel对象并加入到当前pytaichi(上文提到的全局PyTaichi对象)的program中去。我们来看一下create_kernel的c++实现吧,这里其实在c++ pybind文件中,是一个匿名函数

.def(
    "create_kernel",
    [](Program *program, const std::function<void(Kernel *)> &body,
       const std::string &name, AutodiffMode autodiff_mode) -> Kernel * {
      py::gil_scoped_release release;
      return &program->kernel(body, name, autodiff_mode);
    },
      
Kernel &Program::kernel(const std::function<void(Kernel *)> &body,
                 const std::string &name = "",
                 AutodiffMode autodiff_mode = AutodiffMode::kNone) {
    // Expr::set_allow_store(true);
    auto func = std::make_unique<Kernel>(*this, body, name, autodiff_mode);
    // Expr::set_allow_store(false);
    kernels.emplace_back(std::move(func));
    return *kernels.back();
  }

注意这里pybind中的py::gil_scoped_release release;这句话,这个是为了多线程情况下对GIL锁的处理,Python C API 规定全局解释器锁 (GIL) 必须始终由当前线程持有才能安全访问 Python 对象。因此,当 Python 通过 pybind11 调用 C++ 时,必须持有 GIL,而 pybind11 永远不会隐式释放 GIL,所以我们就需要手动释放和获取锁。

pybind11 需要确保它正在调用 Python 代码时保持 GIL。如上,这里有趣的一点是我们这里传入的其实是一个Python函数给std::function类型参数作为一个回调,当c++调用该回调函数时需要确保持有GIL,这段会很复杂,我会在后续c++部分的生命周期的时候话一段时间来详解这一块。

回到Python部分,接下来要处理的就是kernel装饰器所装饰的play函数本身了,get_function_body具体内容简单来说就是非常爽快的返回了一个闭包。这个闭包我们之后会用到,这里就直接将他加入到了pytaichi的compiled_functions字典中。之后将C++ Kernekl对象加入到当前Python Kernel类中的compiled_kernels字典中。

在初始化Kernel阶段,会执行传入的回调函数,我们来具体看一下那个回调函数

def taichi_ast_generator(kernel_cxx):
    if self.runtime.inside_kernel:
        raise TaichiSyntaxError(......)
    self.runtime.inside_kernel = True
    self.runtime.current_kernel = self
    try:
        ctx.ast_builder = kernel_cxx.ast_builder()
        transform_tree(tree, ctx)
        if not ctx.is_real_function:
            if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
                raise TaichiSyntaxError(
                    "Kernel has a return type but does not have a return statement"
                )
    finally:
        self.runtime.inside_kernel = False
        self.runtime.current_kernel = None

其实这个回调函数传入的kernel_cxx就是刚刚创建的c++的kernel类,获取的ast_build就FrotentContext默认的ASTBuilder, FrontentContext在创建Kernel示例时会创建。

之后运行transform_tree函数,这一步其实这个回调最关键的时刻:

这一步稍显复杂,这里通过递归来完成了build

至此ensure_compiled功能就走完了,回到最初的__call__函数,最后操作即是从pytaichi的compiled_functions取出之前加入的闭包,并将闭包的结果当做kernel运行的结果返回。到此具体的编译过程正式展开(上面存在一步编译了,利用ast模块对play源码进行了parser,获取到了AST树,这里其实有一定优化空间,Python在对py文件整体编译时保存了AST树,可惜的是无法从Python语言层面去获取到这个内部对象,不过可以拓展cpython来实现这一功能,但是稍显鸡肋,小量代码的parser速度是非常快的,这一部分提升可以忽略不计了)。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值