pytorch native functions调用 调研
本片文章调研的是pytorch通过native_functions.yaml注册的自定义算子,在调用的时候,调用栈是什么样的
测试用的python文件是
import torch
a = torch.randn(5,3,requires_grad=True)
b = torch.randn(5,3,requires_grad=True)
print(a.myrelu())
print(a.relu())
a.relu().backward(b)
print(a.grad)
a.myrelu().backward(b)
print(a.grad/2)
调用栈中和pytorch相关的最底层的栈是
static PyObject * THPVariable_myrelu(PyObject* self_, PyObject* args)
{
HANDLE_TH_ERRORS
Tensor& self = reinterpret_cast<THPVariable*>(self_)->cdata;
if(check_has_torch_function(self_)) {
return handle_torch_function(self_, "myrelu");
}
// aten::myrelu(Tensor self) -> Tensor
auto dispatch_myrelu = [](Tensor & self) -> Tensor {
pybind11::gil_scoped_release no_gil;
return self.myrelu();
};
return wrap(dispatch_myrelu(self));
END_HANDLE_TH_ERRORS
}
由于调用的是一个变量自身的方法,所以进入的是python_variable_methods.cpp
这个文件是编译的时候自动生成的
这里定义了一个lambda函数,真正的调用发生在self.myrelu()
而后,进入TensorMethods.cpp。这个文件也是自动生成的
// aten::myrelu(Tensor self) -> Tensor
Tensor Tensor::myrelu() const {
#ifdef USE_STATIC_DISPATCH
at::AutoNonVariableTypeMode _var_guard(true);
DispatchKeySet _dk_set = c10::detail::multi_dispatch_key_set(*this);
DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
DispatchKey _dk = c10::impl::dispatchTypeId(_dk_set, _dk_mask);
switch (dispatchKeyToBackend(_dk)) {
<