背景
在PyTorch的Tensor系列上一篇文章中:
Gemfield:PyTorch的Tensor(上)zhuanlan.zhihu.comGemfield介绍了一个Tensor的创建过程,特别是在创建一个Tensor的时候,调用栈从Python到C++再回到Python的过程。与此同时,在内存中对应的是一个Variable实例的创建(严格来说,Variable实例的某个field也是Variable实例)。
在本文,Gemfield将介绍PyTorch的Tensor中autograd相关的部分。autograd是PyTorch之所以是神经网络框架的一个重要原因。autograd机制提供了对Tensor上所有操作自动求微分的功能。我们知道,对于一个Variable来说,它的唯一数据成员就是impl_,这个impl_成员是TensorImpl 类型,在初始化阶段impl_会被实例化为一个Variable::Impl的实例(TensorImpl的子类):
Variable --> impl_ = Variable::Impl实例
对于一个Variable的autograd来说,autograd的部分就体现在impl_的autograd_meta_成员上。在初始化阶段,autograd_meta_会被初始化为一个Variable::AutogradMeta的实例(AutogradMetaInterface的子类),或者会被初始化为一个Variable::DifferentiableViewMeta的实例(Variable::AutogradMeta的子类),然后通过Variable的 get_autograd_meta()来访问。实际上,autograd_meta_正是一个Variable是普通tensor还是带autograd功能的tensor的唯一标识:
#1 Variable是个Tensor,没有requires_grad
Variable --> impl_ --> autograd_meta_ = None
#2
Variable --> impl_ --> autograd_meta_ = Variable::AutogradMeta实例
#3
Variable --> impl_ --> autograd_meta_ = Variable::DifferentiableViewMeta实例
而一个Variable::AutogradMeta实例有如下成员,这些成员正是PyTorch autograd系统的中坚:
# Variable::AutogradMeta 和 Variable::DifferentiableViewMeta
Variable grad_;
std::shared_ptr<Function> grad_fn_;
std::weak_ptr<Function> grad_accumulator_;
VariableVersion version_counter_;
std::vector<std::shared_ptr<FunctionPreHook>> hooks_;
bool requires_grad_;
bool is_view_;
uint32_t output_nr_;
# 仅 Variable::DifferentiableViewMeta
Variable base_;
uint32_t attr_version;
- 1,grad_是另外一个Variable,存储当前Variable实例的梯度;
- 2,grad_fn是个Function的实例,非leaf variables才有。通过Variable的grad_fn()来访问,实际上,PyTorch中就是通过是否grad_fn_ == nullptr来判断一个Variable是否是leaf variable的;
- 3,grad_accumulator_是个Function的实例,只有leaf variables才有。通过Variable的grad_accumulator()来访问;
- 4,version_counter_里有个version number;
- 5,hooks_可以是一组;
- 6,requires_grad_ 是个flag,表明此Variable实例是否需要grad;
- 7,is_view_是个flag,表明此Variable实例是否是个view(没有实际存储,基于base的variable);
- 8,output_nr_是个数字;
- 9,base_是view的base variable;
- 10,attr_version是个数字。
我们通过下面这一小段代码来演示下这个能力:
gemfield = torch.ones(2, 2, requires_grad=True)
syszux = gemfield + 2
civilnet = syszux * syszux * 3
gemfieldout = civilnet.mean()
gemfieldout.backward()
特别的,对于在python会话中的每一步操作,gemfield都将映射到内存上类实例中的成员/结构体的变化。
Tensor创建:gemfield = torch.ones(2, 2, requires_grad=True)
我们使用gemfield = torch.ones(2, 2, requires_grad=True) 语句来创建了一个tensor。在https://zhuanlan.zhihu.com/p/54896021一文中已经介绍过了,这个调用会在内存中产生如下一个Variable实例:
#gemfield
Variable实例 --> Variable::Imple实例 --> tensor data --> TensorImpl实例 --> Storage实例 = [[1., 1.],[1., 1.]]
--> autograd_meta --> grad_ (又一个Variable实例) = None
--> grad_fn_ (Function实例)= None
--> grad_accumulator_ (Function实例)= None
--> version_counter_ = 0
--> hooks_ len = 0
--> requires_grad_ = True
--> is_view_ = false
--> output_nr_ = 0
--> base_ = Not exist
这个gemfield变量就是图中的leaf,为什么呢?因为这是用户直接创建的(不是经过计算得到的),位于图中最“底端/外侧”的位置,没有子节点。这个时候,Tensor gemfield的grad是None,grad_fn是None。output_nr_为0,表明这个Variable是function的第1个输出。
Tensor的简单加法:syszux = gemfield + 2
我们使用 syszux = gemfield + 2 来得到一个新的Tensor,名字为syszux。这个加法嘛,在初始化的时候已经和C++中的THPVariable_add函数绑定上,并注册到Python的torch._C._TensorBase符号上了:
PyMethodDef variable_methods[] = {
{
"__add__", (PyCFunction)THPVariable_add, METH_VARARGS | METH_KEYWORDS, NULL},
......
而THPVariable_add的定义如下:
static PyObject * THPVariable_add(PyObject* self_, PyObject* args, PyObject* kwargs)
{
......
return wrap(dispatch_add(self, r.tensor(0), r.scalar(1)));
}
1,scalar to tensor
在这个函数中,首先要将syszux = gemfield + 2 中的2从标量转换为tensor,这个转换逻辑如下:
auto tensor = scalar_to_tensor(scalar);
tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
return autograd::make_variable(tensor);
现在scalar 2已经变成了内存中一个Variable的实例,在add真正执行之前,在内存中已经有2个Variable实例了,分别是gemfield和2:
#gemfield
Variable实例 --> Variable::Imple实例 --> tensor data --> TensorImpl实例 --> Storage实例 = [[1., 1.],[1., 1.]]
--> autograd_meta --> grad_ (又一个Variable实例) = None
--> grad_fn_ (Function实例)= None
--> grad_accumulator_ (Function实例)= None
--> version_counter_ = 0
--> hooks_ len = 0
--> requires_grad_ = True
--> is_view_ = false
--> output_nr_ = 0
--> base_ = Not exist
#scalar 2
Variable实例 --> Variable::Imple实例 --> tensor data --> Ten