PyTorch源码浅析:简介
这个系列文章自底向上针对PyTorch核心源码进行解析,从Tensor库→神经网络算符→自动微分引擎→Python扩展,一共五篇。代码较多,理解有限,如发现理解不当或表达不妥的地方,还请在评论区指出。
目录
1. THTensor
PyTorch中Tensor的存储和表示分开,多个THTensor可能共享一个THStorage,每个THTensor可能拥有不同的view(e.g. size, stride)。这样设计的好处是,有时看起来不一样的数据底层是共享的,比如矩阵与矩阵的转置、二维矩阵与二维矩阵变成一维时的矩阵。这部分的主要实现在pytorch/aten文件夹中,这里面既实现了底层的Tensor操作库,也封装了名为 ATen 的 C++11接口。
2. THC
这篇主要看 Torch CUDA 部分,对应源码目录aten/src/THC,里面包含了许多C++和CUDA代码。这部分实现了操作 THCTensor 和 THCStorage 的接口,不过底层用的数据结构还是TensorImpl和StorageImpl。THC里的接口也是通过C语言范式实现的,但是Apply系列操作不再由宏来实现,而是使用了C++模板。其他的区别还有allocator不同,以及多了 THCState 结构。
3. NN
THNN是一个用C语言实现的神经网络模块的库,提供的功能非常底层。它实现了许多基础的神经网络模块,包括线性层,卷积层,Sigmoid等各种激活层,一些基本的loss函数,这些API都声明在THNN/generic/THNN.h中。每个模块都实现了前向传导(forward)和后向传导(backward)的功能。THCUNN则是对应模块的CUDA实现。
4. Autograd
这篇博客介绍 PyTorch 中自动微分引擎的实现,主要分为三部分:首先简要介绍一下计算图的原理;然后介绍 PyTorch 中与 autograd 的相关数据结构和backward()函数的实现,数据结构包括torch::autograd::Variable,torch::autograd::Function等;最后讲一下动态建立计算图的实现,这部分代码涉及到动态派发机制,而且都是用脚本生成的,不太容易理解。
5. Python扩展
这篇是本系列最后一篇博客了,介绍一下前面的C++代码怎么与Python交互,或者说Python里怎么调用C++代码进行高效的计算。首先简单介绍一下预备知识,既Python的C扩展通常怎么写;然后以比较核心的数据结构 Tensor 和 Storage 为例看一下它们怎么转换为Python类型的;最后稍带点儿Python自微分函数的实现。
源码目录结构
pytorch
├── aten # ATen: C++ Tensor库
│ ├── CMakeLists.txt
│ ├── conda
│ ├── src
│ │ ├── ATen # C++ bindings
│ │ ├── README.md
│ │ ├── TH # torch tensor
│ │ ├── THC # torch cuda
│ │ ├── THCUNN # torch cuda nn
│ │ └── THNN # torch nn
│ └── tools
├── c10 # 这里面也包含一些Tensor实现
│ ├── CMakeLists.txt
│ ├── core
│ ├── cuda
│ ├── hip
│ ├── macros
│ ├── test
│ └── util
├── caffe2 # caffe2
├── tools
│ ├── autograd # 生成自微分相关函数的工具
│ ├── ...
│ └── shared
├── torch # Python模块
│ ├── autograd
│ ├── csrc # C++相关源码
│ │ ├── autograd # 自动微分引擎实现
│ │ ├── cuda
│ │ ├── distributed
│ │ ├── generic
│ │ ├── jit
│ │ ├── multiprocessing
│ │ ├── nn
│ │ ├── tensor
│ │ ├── utils
│ │ ├── ...
│ │ └── utils.h
│ ├── cuda
│ ├── nn
│ ├── ...
│ ├── storage.py
│ └── tensor.py
├── ...
└── ubsan.supp
代码统计
注:仅统计了torch和aten两个核心文件夹。
感受
一开始只是心血来潮觉得这学期反正不是很忙就立了个flag决定学期内把PyTorch源码看一遍,看的过程很受苦,庆幸最终还是坚持下来了,收获也很大。除了理解了PyTorch是如何运行的、输出这五篇博客之外,我对C++的理解也有显著提升,因为PyTorch大部分代码是用C++写的,各种新特性简直刷新了我对这门语言的认识,由此也专门记了一篇关于C++的笔记。
简单说一下我的阅读方法。面对这么多的代码和文件,一下子肯定不知所措,尤其是阅读新模块的时候,我首先会尝试找到该模块的说明,通过README.md或前人的博客或API文档了解下该模块大概功能和结构,然后整体(粗略)浏览一遍该模块的代码,对每个文件里的代码是做什么的有个大致概念,最后再根据自己的理解选择性地进行精读。