pytorch中关于tensor的深入操作(持续更新)

tensor.clone()、tensor.detach()和tensor.data的作用和区别

参考链接列表:

  • https://blog.csdn.net/qq_37692302/article/details/107459525
  • https://blog.csdn.net/yagreenhand/article/details/104748886
  1. tensor.clone()

    返回tensor的拷贝,返回的新tensor和原来的tensor具有同样的大小和数据类型

    • 原tensor的requires_grad=True
    • clone()返回的tensor是中间节点,梯度会流向原tensor,即返回的tensor的梯度会叠加在原tensor上
    >>> import torch
    >>> a = torch.tensor(1.0, requires_grad=True)
    >>> b = a.clone()
    >>> id(a), id(b)  # a和b不是同一个对象
    (140191154302240, 140191145593424)
    >>> a.data_ptr(), b.data_ptr()  # 也不指向同一块内存地址
    (94724518544960, 94724519185792)
    >>> a.requires_grad, b.requires_grad  # 但b的requires_grad属性和a的一样,同样是True
    (True, True)
    >>> c = a * 2
    >>> c.backward()
    >>> a.grad
    tensor(2.)
    >>> d = b * 3
    >>> d.backward()
    >>> b.grad  # b的梯度值为None,因为是中间节点,梯度值不会被保存
    >>> a.grad  # b的梯度叠加在a上
    tensor(5.)
    
    • 原tensor的requires_grad=False
    >>> import torch
    >>> a = torch.tensor(1.0)
    >>> b = a.clone()
    >>> id(a), id(b)  # a和b不是同一个对象
    (140191169099168, 140191154762208)
    >>> a.data_ptr(), b.data_ptr()  # 也不指向同一块内存地址
    (94724519502912, 94724519533952)
    >>> a.requires_grad, b.requires_grad  # 但b的requires_grad属性和a的一样,同样是False
    (False, False)
    >>> b.requires_grad_()
    >>> c = b * 2
    >>> c.backward()
    >>> b.grad
    tensor(2.)
    >>> a.grad  # None
    
  2. tensor.detach()

    从计算图中脱离出来。

    返回一个新的tensor,新的tensor和原来的tensor共享数据内存,但不涉及梯度计算,即requires_grad=False。修改其中一个tensor的值,另一个也会改变,因为是共享同一块内存,但如果对其中一个tensor执行某些内置操作,则会报错,例如resize_、resize_as_、set_、transpose_。

    >>> import torch
    >>> a = torch.rand((3, 4), requires_grad=True)
    >>> b = a.detach()
    >>> id(a), id(b)  # a和b不是同一个对象了
    (140191157657504, 140191161442944)
    >>> a.data_ptr(), b.data_ptr()  # 但指向同一块内存地址
    (94724518609856, 94724518609856)
    >>> a.requires_grad, b.requires_grad  # b的requires_grad为False
    (True, False)
    >>> b[0][0] = 1
    >>> a[0][0]  # 修改b的值,a的值也会改变
    tensor(1., grad_fn=<SelectBackward>)
    >>> b.resize_((4, 3))  # 报错
    RuntimeError: set_sizes_contiguous is not allowed on a Tensor created from .data or .detach().
    

    补充在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新

  3. tensor.clone().detach() 和 tensor.detach().clone()

    两者的结果是一样的,即返回的tensor和原tensor在梯度上或者数据上没有任何关系,一般用前者。

  4. tensor.data和tensor.detach()

    • 返回的tensor的相同点
      • 都和tensor共享同一块数据
      • 都和tensor的计算历史无关
      • requires_grad=False
    • 返回的tensor的不同点
      • y=x.data在某些情况下不安全,建议两者都可以使用的场景下使用tensor.detach()(tensor.data不能被autograd追踪求微分;tensor.detach()在反向传播时,能通in-place操作报告给autograd)
  5. tensor.detach用法

    如果我们有两个网络 A,B , 两个关系是这样的 y=A(x),z=B(y) 现在我们想用 z.backward() 来为 B 网络的参数来求梯度,但是又不想求 A 网络参数的梯度。
    #y=A(x), z=B(y) 求B中参数的梯度,不求A中参数的梯度
    #第一种方法
    y = A(x)
    z = B(y.detach())
    z.backward()

待更新

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值