1. 起因
曾经碰到过别人的模型prelu在内部的推理引擎算出的结果与其在原始框架PyTorch中不一致的情况,虽然理论上大家实现的都是一个算法,但是从参数上看,因为经过了模型转换,中间做了一些调整。为了确定究竟是初始参数传递就出了问题还是在后续传递过程中继续做了更改、亦或者是最终算法实现方面有着细微差别导致最终输出不同,就想着去看一看PyTorch一路下来是怎么做的。
但是代码跟着跟着就跟丢了,才会发现,PyTorch真的是一个很复杂的项目,但就像舌尖里面说的,环境越是恶劣,回报越是丰厚。为了以后再想跟踪的时候方便,因此决定以PReLU为例静态梳理一下PyTorch的代码结构。捣鼓的这些天,对如何构建一个带有C/C++代码的Python又有了新的了解,这也算是意外的收获吧。
2. 历程
首先,我们从PReLU的导入路径torch.nn.PReLU中知道,他应在径进torch\nn\之下,进入该路径虽然没看到,但是我们在该路径下的__init__.py中知道,其实它就在torch\nn\modules\activation.py中。类PReLU最终调用了从torch\nn\functional.py导入的prelu方法。顺腾摸瓜,找到prelu,它长下面这样:
def prelu(input, weight):
# type: (Tensor, Tensor) -> Tensor
if not torch.jit.is_scripting():
if type(input) is not Tensor and has_torch_function((input,)):
return handle_torch_function(prelu, (input,), input, weight)
return torch.prelu(input, weight)
经过人脑对代码的一番执行你会发现,第一个if条件满足,而第二个if不满足。因此,最终想看算法,得去看torch.prelu()。好吧,接着干……
一番搜寻之后你会发现,Python代码中在torch这个包下面你是找不到prelu的定义的。但是绝望之际我们在torch包的__init__.py之中看到看下面几行代码:
# pytorch\torch\__init__.py
# 为了简洁,省去不必要代码,详细代码参见pytorch\torch\__init__.py
try:
# _initExtension is chosen (arbitrarily) as a sentinel.
from torch._C import _initExtension
__all__ += [name for name in dir(_C)
if name[0] != '_' and
not name.endswith('Base')]
if TYPE_CHECKING:
# Some type signatures pulled in from _VariableFunctions here clash with
# signatures already imported. For now these clashes are ignored; see
# PR #43339 for details.
from torch._C._VariableFunctions import * # type: ignore
for name in dir(_C._VariableFunctions):
if name.startswith('__'):
continue
globals()[name] = getattr(_C._VariableFunctions, name)
__all__.append(name)
这是全村最后的希望了。我们知道__all__中的名字其实就是该模块有意暴露出去的API。
什么意思呢?也就是说虽然我们明文上已经看不到了prelu的定义,但是这几行代码表明有一大堆身份不明的API被暗搓搓的导入了,这其中就很有可能存在我们朝思暮想的prelu。
那么我们怎么凭借这么一点微弱的线索确定我们的猜测到底对不对呢?这里我们就用到了Python的一个关键知识:C/C++扩展。(戳这里《使用C语言编写Python模块-引子》《Python调用C++之PYBIND11简介》了解更多)
我们知道Python C/C++扩展有着固定的格式,只要我们找到模块初始化入口,就能顺藤摸瓜找到该模块暴露的给Python解释器所有函数。Python 3中的初始化函数样子为PyInit_,其中就是模块的名字。例如在前面提到的from torch._C import *中,模块torch下面必要有一个名字为_C的子模块。因此它的初始化函数应该为PyInit__C,我们搜索该名字就能找到模块入口。当然另外还有一种方法,就是查看setup.py文件中关于扩展的描述信息:
// pytorch\setup.py
main_sources = ["torch/csrc/stub.c"]
C = Extension("torch._C",
libraries=main_libraries,
sources=main_sources,
language='c',
extra_compile_args=main_compile_args + extra_compile_args,
include_dirs=[],
library_dirs=library_dirs,
extra_link_args=extra_link_args + main_link_args + make_relative_rpath_args('lib'))
extensions.append(C)
不管是通过搜索还是查看setup.py,我们最终都成功定位到了位于pytorch\torch\csrc\stub.c下的模块初始化函数PyInit__C(void),并进一步跟踪其调用的函数initModule(),便可以知道具体都暴露了哪些API给Python解释器。
// pytorch\torch\csrc\stub.c
PyMODINIT_FUNC PyInit__C(void)
{
return initModule();
}
// pytorch\torch\csrc\Module.cpp
initModule()
进入initModule()寻找一番,你会发现,模块_C中依然没有prelu的Python接口。怎么办?莫慌,通过前面对torch.__init_