pytorch中的.clone() 和 .detach()复制变量、共享内存变量

前言:

最近博主搭建网络的时候需要操作loss函数,创建自己的loss函数不免遇到创建变量的时候发生,于是会遇到以下几个问题:

1.创建在CPU上的变量,无法与网络计算得到的GPU变量进行下一步操作。

2.直接从另外一个GPU变量copy过来的变量不知道是否会产生联动影响

经过查阅pytorch的手册发现两个工具函数:

.clone()  和 .detach()

一、.clone() vs .detach() 函数 

Torch 为了提高速度,向量或是矩阵的赋值是指向同一内存的,这不同于 Matlab。

如果需要新建一个变量开辟新的存储地址而不是引用,可以用 clone() 进行深拷贝,

首先我们来打印出来clone() 和 detach()操作后的数据类型定义变化:

(1). 简单打印类型

import torch

a = torch.tensor(1.0, requires_grad=True)
b = a.clone()
c = a.detach()

print(a) 
print(b)
print(c)

a.data *= 3
b += 1

print(a)
print(b)
print(c)

输出的结果如下:

上面可以分为两部分:

第一部分直接使用函数进行操作之后的结果,第二部分是加入一些运算之后的结果。 

其中:

clone操作在一定程度上可以视为是一个identity-mapping函数。grad_fn=<CloneBackward>,表示clone后的返回值是个中间变量,因此支持梯度的回溯。

detach()操作后的tensor与原始tensor共享数据内存,当原始tensor在计算图中数值发生反向传播等更新之后,detach()的tensor值也发生了改变

注意: 在pytorch中我们不要直接使用id是否相等来判断tensor是否共享内存,这只是充分条件,因为也许底层共享数据内存,但是仍然是新的tensor,比如detach(),如果我们直接打印id会出现以下情况。

import torch as t
a = t.tensor([1.0,2.0], requires_grad=True)
b = a.detach()
print(id(a))
print(id(b))

显然直接打印出来的id不等,但是它们确实是共享内存。

(2). clone()的梯度回传

detach()函数可以返回一个完全相同的tensor,与旧的tensor共享内存,脱离计算图,不会牵扯梯度计算。

而clone充当中间变量,会将梯度传给源张量进行叠加,但是本身不保存其grad,即值为None

import torch

a = torch.tensor(1.0, requires_grad=True)
a_ = a.clone()

y = a**2
z = a ** 2+a_ * 3

y.backward()
print(a.grad) # 2

z.backward()
print(a_.grad)# None. 中间variable,无grad
print(a.grad)

使用torch.clone()获得的新tensor和原来的数据不再共享内存,但仍保留在计算图中,clone操作在不共享数据内存的同时支持梯度梯度传递与叠加,所以常用在神经网络中某个单元需要重复使用的场景下。

通常如果原tensor的requires_grad=True,则:

  • clone()操作后的tensor requires_grad=True
  • detach()操作后的tensor requires_grad=False。

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

import torch

torch.manual_seed(0)

x= torch.tensor([1., 2.], requires_grad=True)

clone_x = x.clone()

detach_x = x.detach()

clone_detach_x = x.clone().detach()

f = torch.nn.Linear(2, 1)

y = f(x)

y.backward()

print(x.grad)

print(clone_x.requires_grad)

print(clone_x.grad)

print(detach_x.requires_grad)

print(clone_detach_x.requires_grad)

'''

输出结果如下:

tensor([-0.0053, 0.3793])

True

None

False

False

'''

另一个比较特殊的是当源张量的 require_grad=False,clone后的张量 require_grad=True,此时不存在张量回传现象,可以得到clone后的张量求导。

如下:

1

2

3

4

5

6

7

8

9

10

11

12

13

import torch

a = torch.tensor(1.0)

a_ = a.clone()

a_.requires_grad_() #require_grad=True

y = a_ ** 2

y.backward()

print(a.grad) # None

print(a_.grad)

'''

输出:

None

tensor(2.)

'''

了解了两者的区别后我们常与其他函数进行搭配使用,实现数据拷贝后的其他需要。

比如我们经常使用view()函数对tensor进行reshape操作。返回的新Tensor与源Tensor可能有不同的size,但是是共享data的,即其中的一个发生变化,另外一个也会跟着改变。

需要注意的是view返回的Tensor与源Tensor是共享data的,但是依然是一个新的Tensor(因为Tensor除了包含data外还有一些其他属性),两者id(内存地址)并不一致。

1

2

3

4

5

x = torch.rand(2, 2)

y = x.view(4)

x += 1

print(x)

print(y) # 也加了1

view() 仅仅是改变了对这个张量的观察角度,内部数据并未改变。这时候想返回一个真正新的副本(即不共享data内存)该怎么办呢?Pytorch还提供了一个reshape()可以改变形状,但是此函数并不能保证返回的是其拷贝,所以不推荐使用。推荐先用clone创造一个副本然后再使用view。参考此处

1

2

3

4

5

6

7

8

9

10

11

12

13

14

x = torch.rand(2, 2)

x_cp = x.clone().view(4)

x += 1

print(id(x))

print(id(x_cp))

print(x)

print(x_cp)

'''

140568935036464

140568935035816

tensor([[0.4963, 0.7682],

 [0.1320, 0.3074]])

tensor([[1.4963, 1.7682, 1.1320, 1.3074]])

'''

另外使用clone()会被记录在计算图中,即梯度回传到副本时也会传到源Tensor。

总结:

  • torch.detach() — 新的tensor会脱离计算图,不会牵扯梯度计算
  • torch.clone() — 新的tensor充当中间变量,会保留在计算图中,参与梯度计算(回传叠加),但是一般不会保留自身梯度。
    原地操作(in-place, such as resize_ / resize_as_ / set_ / transpose_) 在上面两者中执行都会引发错误或者警告。
  • 共享数据内存是底层设计,并不能简单的通过直接打印tensor的id地址进行判断,需要在进行赋值或运算操作后打印比较数据的变化进行判断。
  • 复制操作可以根据实际需要进行结合使用。

引用官方文档的话:如果你使用了in-place operation而没有报错的话,那么你可以确定你的梯度计算是正确的。另外尽量避免in-place的使用。

像y = x + y这样的运算会新开内存,然后将y指向新内存。我们可以使用Python自带的id函数进行验证:如果两个实例的ID相同,则它们所对应的内存地址相同。

该篇文章为转载:https://www.jb51.net/article/201734.htm

如有侵权请联系博主删除。

2024.04.16更新

import torch

# 创建原始数据
original_data = [1, 2, 3, 4, 5]
dataset = torch.tensor(original_data)


# 克隆数据到新变量
cloned_dataset = dataset.clone()
detached_dataset = dataset.detach()

# 修改数据
dataset[1]=999
dataset[-1]=888
# 打印原始数据和克隆数据
print("Original Dataset:", dataset.data)
print("Cloned Dataset:", cloned_dataset.data)
print("Detached Dataset:", detached_dataset.data)

输出:

Original Dataset: tensor([  1, 999,   3,   4, 888])
Cloned Dataset: tensor([1, 2, 3, 4, 5])
Detached Dataset: tensor([  1, 999,   3,   4, 888])

说明.clone()有类似copy.deep()的功能

  • 22
    点赞
  • 85
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值