【Pytorch】对比clone、detach以及copy_等张量复制操作

pytorch提供了clonedetachcopy_new_tensor等多种张量的复制操作,尤其前两者在深度学习的网络架构中经常被使用,本文旨在对比这些操作的差别。

1. clone

返回一个和源张量同shapedtypedevice的张量,与源张量不共享数据内存,但提供梯度的回溯

下面,通过例子来详细说明:

示例

(1)定义

import torch

a = torch.tensor(1.0, requires_grad=True, device="cuda", dtype=torch.float64)

a_ = a.clone()
print(a_)   # tensor(1., device='cuda:0', dtype=torch.float64, grad_fn=<CloneBackward>)

注意grad_fn=<CloneBackward>,说明clone后的返回值是个中间variable,因此支持梯度的回溯。因此,clone操作在一定程度上可以视为是一个identity-mapping函数。

(2)梯度的回溯

clone作为一个中间variable,会将梯度传给源张量进行叠加。

import torch

a = torch.tensor(1.0, requires_grad=True)
y = a ** 2 
a_ = a.clone()
z = a_ * 3
y.backward()
print(a.grad)   # 2
z.backward()
print(a_.grad)   # None. 中间variable,无grad
print(a.grad)    # 5. a_的梯度会传递回给a,因此2+3=5

但若源张量的require_grad=False,而clone后的张量require_grad=True,显然此时不存在张量回溯现象,clone后的张量可以求导。

import torch

a = torch.tensor(1.0)
a_ = a.clone()
a_.requires_grad_()

y = a_ ** 2
y.backward()
print(a.grad)   # None
print(a_.grad)   # 2.  可得到导数

(3)张量数据非共享

import torch

a = torch.tensor(1.0, requires_grad=True)
a_ = a.clone()

a.data *= 3
a_ += 1

print(a)   # tensor(3., requires_grad=True)
print(a_)  # tensor(2., grad_fn=<AddBackward0>).  注意grad_fn的变化

综上论述,clone操作在不共享数据内存的同时支持梯度回溯,所以常用在神经网络中某个单元需要重复使用的场景下。

2. detach

detach的机制则与clone完全不同,即返回一个和源张量同shapedtypedevice的张量,与源张量共享数据内存,但不提供梯度计算,即requires_grad=False,因此脱离计算图。

同样,通过例子来详细说明:

(1)定义

import torch

a = torch.tensor(1.0, requires_grad=True, device="cuda", dtype=torch.float64)

a_ = a.detach()
print(a_)   # tensor(1., device='cuda:0', dtype=torch.float64)

(2)脱离原计算图

import torch

a = torch.tensor(1.0, requires_grad=True)
y = a ** 2 
a_ = a.detach()
print(a_.grad)    # None,requires_grad=False
a_.requires_grad_()  # 强制其requires_grad=True,从而支持求导

z = a_ * 3
y.backward()
z.backward()

print(a.grad)    # 2,与a_无关系
print(a_.grad)   #

可见,detach后的张量,即使重新定义requires_grad=True,也与源张量的梯度没有关系。

(3)共享张量数据内存

import torch

a = torch.tensor(1.0, requires_grad=True)
a_ = a.detach()

print(a)    # tensor(1., requires_grad=True)
print(a_)   # tensor(1.)

a_ += 1   
print(a)     # tensor(2., requires_grad=True)
print(a_)    # tensor(2.)

a.data *= 2
print(a)    # tensor(4., requires_grad=True)
print(a_)    # tensor(4.)

综上论述,detach操作在共享数据内存的脱离计算图,所以常用在神经网络中仅要利用张量数值,而不需要追踪导数的场景下。

3. clone和detach联合使用

clone提供了非数据共享的梯度追溯功能,而detach又“舍弃”了梯度功能,因此clonedetach意味着着只做简单的数据复制,既不数据共享,也不对梯度共享,从此两个张量无关联。

置于是先clone还是先detach,其返回值一样,一般采用tensor.clone().detach()

4. new_tensor

new_tensor可以将源张量中的数据复制到目标张量(数据不共享),同时提供了更细致的devicedtyperequires_grad属性控制:

new_tensor(data, dtype=None, device=None, requires_grad=False) 

注意:其默认参数下的操作等同于.clone().detach(),而requires_grad=True时的效果相当于.clone().detach()requires_grad_(True)。上面两种情况都推荐使用后者。

5. copy_

copy_同样将源张量中的数据复制到目标张量(数据不共享),其devicedtyperequires_grad一般都保留目标张量的设定,仅仅进行数据复制,同时其支持broadcast操作。

a = torch.tensor([[1,2,3], [4,5,6]], device="cuda")
b = torch.tensor([7.0,8.0,9.0], requires_grad=True)
a.copy_(b)
print(a)   # tensor([[7, 8, 9], [7, 8, 9]], device='cuda:0')  

【Ref】:

  1. 关于 pytorch inplace operation, 需要知道的几件事
  • 84
    点赞
  • 189
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
PyTorch中,clone函数用于创建一个张量的副本。通过clone函数创建的副本张量与原始张量具有相同的值,但是它们是不同的张量对象。这意味着副本张量可以独立于原始张量进行操作,而不会影响原始张量梯度计算。 根据引用\[1\]和\[2\]的例子,当使用clone函数创建副本张量时,副本张量会保留梯度信息,可以进行梯度回溯。这意味着副本张量梯度会传递给原始张量,并与原始张量梯度进行叠加。 然而,根据引用\[3\]的例子,clone函数本身不会保存副本张量梯度信息。当对副本张量进行梯度计算时,副本张量梯度为None。只有原始张量会保存梯度信息。 因此,总结来说,clone函数在PyTorch中用于创建一个张量的副本,副本张量与原始张量具有相同的值,但是它们是不同的张量对象。副本张量可以进行独立的操作,但是在梯度计算中,副本张量梯度会传递给原始张量进行叠加,而副本张量本身不保存梯度信息。 #### 引用[.reference_title] - *1* *2* [【python基础】PyTorchclone()、detach()](https://blog.csdn.net/dujuancao11/article/details/121563226)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [pytorch中的.clone() 和 .detach()复制变量、共享内存变量](https://blog.csdn.net/weixin_44503976/article/details/126631909)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值