tensor判断是否相等_PyTorch的Tensor(中)

bf1b3dad2ae9992486c5a58414723598.png

背景

在PyTorch的Tensor系列上一篇文章中:

Gemfield:PyTorch的Tensor(上)​zhuanlan.zhihu.com
9c282381b1be088e399caada4e6c14a1.png

Gemfield介绍了一个Tensor的创建过程,特别是在创建一个Tensor的时候,调用栈从Python到C++再回到Python的过程。与此同时,在内存中对应的是一个Variable实例的创建(严格来说,Variable实例的某个field也是Variable实例)。

在本文,Gemfield将介绍PyTorch的Tensor中autograd相关的部分。autograd是PyTorch之所以是神经网络框架的一个重要原因。autograd机制提供了对Tensor上所有操作自动求微分的功能。我们知道,对于一个Variable来说,它的唯一数据成员就是impl_,这个impl_成员是TensorImpl 类型,在初始化阶段impl_会被实例化为一个Variable::Impl的实例(TensorImpl的子类):

Variable --> impl_ = Variable::Impl实例

对于一个Variable的autograd来说,autograd的部分就体现在impl_的autograd_meta_成员上。在初始化阶段,autograd_meta_会被初始化为一个Variable::AutogradMeta的实例(AutogradMetaInterface的子类),或者会被初始化为一个Variable::DifferentiableViewMeta的实例(Variable::AutogradMeta的子类),然后通过Variable的 get_autograd_meta()来访问。实际上,autograd_meta_正是一个Variable是普通tensor还是带autograd功能的tensor的唯一标识:

#1 Variable是个Tensor,没有requires_grad
Variable --> impl_ --> autograd_meta_ = None

#2
Variable --> impl_ --> autograd_meta_ = Variable::AutogradMeta实例

#3
Variable --> impl_ --> autograd_meta_ = Variable::DifferentiableViewMeta实例

而一个Variable::AutogradMeta实例有如下成员,这些成员正是PyTorch autograd系统的中坚:

# Variable::AutogradMeta 和 Variable::DifferentiableViewMeta
Variable grad_;
std::shared_ptr<Function> grad_fn_;
std::weak_ptr<Function> grad_accumulator_;
VariableVersion version_counter_;
std::vector<std::shared_ptr<FunctionPreHook>> hooks_;
bool requires_grad_;
bool is_view_;
uint32_t output_nr_;

# 仅 Variable::DifferentiableViewMeta
Variable base_;
uint32_t attr_version;
  • 1,grad_是另外一个Variable,存储当前Variable实例的梯度;
  • 2,grad_fn是个Function的实例,非leaf variables才有。通过Variable的grad_fn()来访问,实际上,PyTorch中就是通过是否grad_fn_ == nullptr来判断一个Variable是否是leaf variable的;
  • 3,grad_accumulator_是个Function的实例,只有leaf variables才有。通过Variable的grad_accumulator()来访问;
  • 4,version_counter_里有个version number;
  • 5,hooks_可以是一组;
  • 6,requires_grad_ 是个flag,表明此Variable实例是否需要grad;
  • 7,is_view_是个flag,表明此Variable实例是否是个view(没有实际存储,基于base的variable);
  • 8,output_nr_是个数字;
  • 9,base_是view的base variable;
  • 10,attr_version是个数字。

我们通过下面这一小段代码来演示下这个能力:

gemfield = torch.ones(2, 2, requires_grad=True)
syszux = gemfield + 2
civilnet = syszux * syszux * 3
gemfieldout = civilnet.mean()
gemfieldout.backward()

特别的,对于在python会话中的每一步操作,gemfield都将映射到内存上类实例中的成员/结构体的变化。

Tensor创建:gemfield = torch.ones(2, 2, requires_grad=True)

我们使用gemfield = torch.ones(2, 2, requires_grad=True) 语句来创建了一个tensor。在https://zhuanlan.zhihu.com/p/54896021一文中已经介绍过了,这个调用会在内存中产生如下一个Variable实例:

#gemfield

Variable实例 --> Variable::Imple实例 --> tensor data   --> TensorImpl实例 --> Storage实例 = [[1., 1.],[1., 1.]]
                                    --> autograd_meta --> grad_ (又一个Variable实例) = None
                                                      --> grad_fn_ (Function实例)= None
                                                      --> grad_accumulator_ (Function实例)= None
                                                      --> version_counter_ = 0
                                                      --> hooks_ len = 0
                                                      --> requires_grad_ = True
                                                      --> is_view_ = false
                                                      --> output_nr_ = 0
                                                      --> base_ = Not exist

这个gemfield变量就是图中的leaf,为什么呢?因为这是用户直接创建的(不是经过计算得到的),位于图中最“底端/外侧”的位置,没有子节点。这个时候,Tensor gemfield的grad是None,grad_fn是None。output_nr_为0,表明这个Variable是function的第1个输出。

Tensor的简单加法:syszux = gemfield + 2

我们使用 syszux = gemfield + 2 来得到一个新的Tensor,名字为syszux。这个加法嘛,在初始化的时候已经和C++中的THPVariable_add函数绑定上,并注册到Python的torch._C._TensorBase符号上了:

PyMethodDef variable_methods[] = {
    
  {
    "__add__", (PyCFunction)THPVariable_add, METH_VARARGS | METH_KEYWORDS, NULL},
  ......

而THPVariable_add的定义如下:

static PyObject * THPVariable_add(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ......
  return wrap(dispatch_add(self, r.tensor(0), r.scalar(1)));
}

1,scalar to tensor

在这个函数中,首先要将syszux = gemfield + 2 中的2从标量转换为tensor,这个转换逻辑如下:

auto tensor = scalar_to_tensor(scalar);
tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
return autograd::make_variable(tensor);

现在scalar 2已经变成了内存中一个Variable的实例,在add真正执行之前,在内存中已经有2个Variable实例了,分别是gemfield和2:

#gemfield
Variable实例 --> Variable::Imple实例 --> tensor data   --> TensorImpl实例 --> Storage实例 = [[1., 1.],[1., 1.]]
                                    --> autograd_meta --> grad_ (又一个Variable实例) = None
                                                      --> grad_fn_ (Function实例)= None
                                                      --> grad_accumulator_ (Function实例)= None
                                                      --> version_counter_ = 0
                                                      --> hooks_ len = 0
                                                      --> requires_grad_ = True
                                                      --> is_view_ = false
                                                      --> output_nr_ = 0
                                                      --> base_ = Not exist

#scalar 2
Variable实例 --> Variable::Imple实例 --> tensor data   --> Ten
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值