PyTorch Tensor类:clone和detach的用法

本来是准备分析一下Tensor类的源码的,但是看了看发现这个类的源码实现基本都是在C++上,目前精力有限,所以就算了。现在打算分析一下Tensor中可能比较难用的方法,比如clone,detach。

这些方法之所以难用主要还是因为Tensor支持自动微分,也就是说每个Tensor不止能表示这个Tensor对应的值,还可以表示以这个Tensor为根结点的前向计算图。

Clone方法

我们先看PyTorch的官方文档torch.clone — PyTorch 1.10.0 documentation

 很多人可能对clone是可微的这句话不是很懂,其实就是论文里偶尔能见到的identify操作,我这里举个例子。

import torch
x = torch.tensor([1.0],requires_grad = True)
y = x.clone()
y.backward()
print("x.grad:",x.grad)
"""
输出结果:
x.grad: tensor([1.])
"""

可以看到我们对y进行反向传播,x的梯度为1。所以实际上y = x.clone()类似于数学表达式中的y=x,但是在python里如果让y=x,是不会给y创建新的内存空间的,这就需要clone了。

或者也可以y = x+0和y=x*1。

import torch
x = torch.tensor([1.0],requires_grad = True)
y = x+0
y.backward()
print("x.grad:",x.grad)
"""
输出结果:
x.grad: tensor([1.])
"""

可以看的,和clone的效果是一致的。

下面以计算图的形式来表述:

(当前计算图中结点代表操作(叶子节点除外),边代表张量)

Detach方法

detach也是一个和计算图关联比较紧密的tensor方法。

还是先看一下PyTorch的官方文档torch.Tensor.detach — PyTorch 1.10.0 documentation

Returns a new Tensor, detached from the current graph.The result will never require gradient.Returned Tensor shares the same storage with the original one. In-place modifications on either of them will be seen, and may trigger errors in correctness checks.

也就是说detach的效果其实就是将一个Tensor的requires_grad置为False。那么一个Tensor的requires_grad为False对这个计算图会有什么影响呢?从效果上来看其实就是以该Tensor为根节点的计算子图在反向传播中不再会得到梯度。

这里继续看一个例子。

import torch
#L = (X+Y)×Z
#dL/dX = Z, dL/dY = Z, dL/dZ = X+Y
X = torch.tensor([1.0],requires_grad = True)
Y = torch.tensor([2.0],requires_grad = True)
Z = torch.tensor([3.0],requires_grad = True)
K = X+Y
L = K*Z
L.backward()
print("X.grad:",X.grad)
print("Y.grad:",Y.grad)
print("Z.grad:",Z.grad)
"""
输出结果:
X.grad: tensor([3.])
Y.grad: tensor([3.])
Z.grad: tensor([3.])
"""

这是一个正常反向传播的情况下的梯度值,接下来我将利用detach使梯度无法反传到X和Y。

import torch
#L = (X+Y)×Z
#dL/dX = Z, dL/dY = Z, dL/dZ = X+Y
X = torch.tensor([1.0],requires_grad = True)
Y = torch.tensor([2.0],requires_grad = True)
Z = torch.tensor([3.0],requires_grad = True)
K = X+Y
L = K.detach()*Z
L.backward()
print("X.grad:",X.grad)
print("Y.grad:",Y.grad)
print("Z.grad:",Z.grad)
"""
输出结果:
X.grad: None
Y.grad: None
Z.grad: tensor([3.])
"""

从计算图上看是这样的(虚线代表无法到达)

 也就是说由于K.detach,导致当前计算中K的requires_grad为False,因此dL/dK这个梯度无法传播到K,从而导致dL/dX和dL/dY都无法被计算。

但是需要注意的是K.detach()并不会修改K本身的requires_grad属性,因此K本身还是可以接收梯度的。

下面是另一个例子:

import torch
#L = (X+Y)×3×Z
#dL/dX = 3Z, dL/dY = 3Z, dL/dZ = 3(X+Y)
X = torch.tensor([1.0],requires_grad = True)
Y = torch.tensor([2.0],requires_grad = True)
Z = torch.tensor([3.0],requires_grad = True)
K = X+Y
L = K.detach()*Z*K
L.backward()
print("X.grad:",X.grad)
print("Y.grad:",Y.grad)
print("Z.grad:",Z.grad)
"""
输出结果:
X.grad: tensor([9.])
Y.grad: tensor([9.])
Z.grad: tensor([9.])
"""

这里我在计算L的时候同时用了K.detach()和K,K.detach带来的梯度虽然无法回传,但是K的梯度是可以回传的,所以X和Y依然存在梯度。

所以如果我们想让从K中出去的所有计算都无法得到梯度,那么应该先使用K = K.detach()。

所以这么来看,detach这个方法其实可以理解为在本次计算中让某个变量变为常数。

总结

目前对于clone和detach的原理分析的比较清楚了,但是具体的使用场景还需要在实践中继续观察,后续看到了我也会补充上来的。

  • 3
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值