1. Kernel编译
回到transform_tree这个方法,其实这里在初始化Kernel的时候就已经调用了,transform方法很简单就是实例化了一个ASTTransformer对象,并调用了其__call__
方法
这里call重载在父类Builder中:
def __call__(self, ctx, node):
method = getattr(self, 'build_' + node.__class__.__name__, None)
try:
if method is None:
error_msg = f'Unsupported node "{node.__class__.__name__}"'
raise TaichiSyntaxError(error_msg)
info = ctx.get_pos_info(node) if isinstance(
node, (ast.stmt, ast.expr)) else ""
with impl.get_runtime().src_info_guard(info):
return method(ctx, node)
except Exception as e:
if ctx.raised or not isinstance(node, (ast.stmt, ast.expr)):
raise e.with_traceback(None)
ctx.raised = True
e = handle_exception_from_cpp(e)
if not isinstance(e, TaichiCompilationError):
msg = ctx.get_pos_info(node) + traceback.format_exc()
raise TaichiCompilationError(msg) from None
msg = ctx.get_pos_info(node) + str(e)
raise type(e)(msg) from None
代码量不大,但却很精华,使用了递归的方式去遍历解析tree。
这里需要稍微解释一下with关键字,with结构是python中非常好玩的结构,大致可以理解成一种包围,这里with的是一个SrcInfoGuard
类型,这里会在执行with内部语句前调用所with对象的__enter__
方法,退出with结构时调用__exit__
方法,我们这里看一下SrcInfoGuard结构:
class SrcInfoGuard:
def __init__(self, info_stack, info):
self.info_stack = info_stack
self.info = info
def __enter__(self):
self.info_stack.append(self.info)
def __exit__(self, exc_type, exc_val, exc_tb):
self.info_stack.pop()
所以这里会在执行method方法前先将info加入到info_stack中,往往也是这种需要栈结构的情况下可以使用with语句。这也是Python比较方便的地方。
我们继续走,这里的method是build_Module,来到build_Module函数:
build_stmt = ASTTransformer()
@staticmethod
def build_Module(ctx, node):
with ctx.variable_scope_guard():
# Do NOT use |build_stmts| which inserts 'del' statements to the
# end and deletes parameters passed into the module
for stmt in node.body:
build_stmt(ctx, stmt)
return None
这里的node.body获取的是一个FunctionDef list,当然这里只有一个元素就是我们上面的play函数,这个FunctionDef类其实CPython中的C 结构体,我们在之前的CPython源码中有提及过,这里再次贴出他的结构:
struct {
identifier name;
arguments_ty args;
asdl_seq *body;
asdl_seq *decorator_list;
expr_ty returns;
string type_comment;
} FunctionDef;
之后再次调用了ASTTransformer的__call__
函数,继续经过with语句,顺利的函数info信息加入到info_stack列表中,这次获取的method是 build_FunctionDef
, 进入FunctionDef:
@staticmethod
def build_FunctionDef(ctx, node):
if ctx.visited_funcdef:
raise TaichiSyntaxError(
f"Function definition is not allowed in 'ti.{'kernel' if ctx.is_kernel else 'func'}'."
)
ctx.visited_funcdef = True
args = node.args
assert args.vararg is None
assert args.kwonlyargs == []
assert args.kw_defaults == []
assert args.kwarg is None
开头是一段检验,包括该节点是否访问,参数类型是否合法(这些在之前也进行了检测,taichi的kernel不允许关键字参数,可变参数等)
def transform_as_kernel():
# Treat return type
if node.returns is not None:
kernel_arguments.decl_ret(ctx.func.return_type,
ctx.is_real_function)
impl.get_runtime().prog.finalize_rets()
for i, arg in enumerate(args.args):
if not isinstance(ctx.func.arguments[i].annotation,
primitive_types.RefType):
ctx.kernel_args.append(arg.arg)
if isinstance(ctx.func.arguments[i].annotation,
annotations.template):
ctx.create_variable(arg.arg, ctx.global_vars[arg.arg])
......
else:
ctx.create_variable(
arg.arg,
kernel_arguments.decl_scalar_arg(
ctx.func.arguments[i].annotation))
# remove original args
node.args.args = []
之后是一个闭包函数transform_as_kernel
,这里首先处理返回值,我们的例子中没有返回值,暂时不考虑。之后是处理了函数行参,这里会对taichid几个特殊类型做判断,当然我们的例子中并不包含这些特殊类型,将参数的名称加入到ctx的kernel_args列表中,最后创建variable context:
def decl_scalar_arg(dtype):
is_ref = False
if isinstance(dtype, RefType):
is_ref = True
dtype = dtype.tp
dtype = cook_dtype(dtype)
arg_id = impl.get_runtime().prog.decl_scalar_arg(dtype)
return Expr(_ti_core.make_arg_load_expr(arg_id, dtype, is_ref))
上面的是decl_scalar_arg
具体是创建了一个Taichi Cpp Expr对象。
with ctx.variable_scope_guard():
build_stmts(ctx, node.body)
最后就是除了函数体,具体就不看了,和处理函数定义类似的方式,就是前端遍历中AST树的创建罢了。
2. Kernel调用
回到我们之前存储的闭包,这一次我们可以好好的分析一下这段代码:
def func__(*args):
assert len(args) == len(
self.arguments
), f'{len(self.arguments)} arguments needed but {len(args)} provided'
tmps = []
callbacks = []
actual_argument_slot = 0
launch_ctx = t_kernel.make_launch_context()
for i, v in enumerate(args):
needed = self.arguments[i].annotation
if isinstance(needed, template):
continue
provided = type(v)
# Note: do not use sth like "needed == f32". That would be slow.
if id(needed) in primitive_types.real_type_ids:
if not isinstance(v, (float, int)):
raise TaichiRuntimeTypeError.get(
i, needed.to_string(), provided)
launch_ctx.set_arg_float(actual_argument_slot, float(v))
elif id(needed) in primitive_types.integer_type_ids:
if not isinstance(v, int):
raise TaichiRuntimeTypeError.get(
i, needed.to_string(), provided)
if is_signed(cook_dtype(needed)):
launch_ctx.set_arg_int(actual_argument_slot, int(v))
actual_argument_slot += 1
......
try:
t_kernel(launch_ctx)
except Exception as e:
e = handle_exception_from_cpp(e)
raise e from None
这一段代码非常长,我们截取来看,首先检测了形参和实参是否一致(话说,之前ensure的时候不是检测过一次吗?)初始化了两个局部变量tmps、callbacks
之后创造了一个LauchContextBuilder对象,这一步依靠pybind绑定的c++实现。之后就是对实参的处理,处理好的参数全部加载到LaouchContext中,最后调用之前创建好的t_kernel,至此python前置部分运行完毕