pytorch中clone()、copy_()、detach()等函数辨析

1、clone()

创建一个tensor与源tensor具有相同的shape、dtype和device,两个tensor不共享内存,但是梯度会叠加在源tensor上。

import torch

a = torch.tensor([1.,2.,3.],requires_grad=True)
b = a.clone()
print(a.data_ptr()) # 2194950282880
print(b.data_ptr()) # 2194950282688

print(a) # tensor([1., 2., 3.], requires_grad=True)
print(b) # tensor([1., 2., 3.], grad_fn=<CloneBackward0>)
c = 2*a
d = 3*b
c.sum().backward()
d.sum().backward()
print(a.grad) # tensor([5., 5., 5.])
print(b.grad) # None,由于b是通过clone运算得到的,不是叶子节点

2、copy_()

与clone类似,调用方法有差异。同样不共享内存,梯度会叠加到源tensor上。

import torch

a = torch.tensor([1.,2.,3.],requires_grad=True,device='cuda')
b = torch.empty_like(a).copy_(a)
print(a) # tensor([1., 2., 3.], device='cuda:0', requires_grad=True)
print(b) # tensor([1., 2., 3.], device='cuda:0', grad_fn=<CopyBackwards>)
print(a.data_ptr()) # 30123491328
print(b.data_ptr()) # 30123491840
c= 2*a
d= 3*b
c.sum().backward()
d.sum().backward()
print(a.grad) # tensor([5., 5., 5.], device='cuda:0')

clone()与copy_()的不同之处在于 copy_()是 in-place operation,可以用于给需要求梯度的tensor初始化。

3、detach()

新的tensor与源tensor共享内存,但是新tensor的 requires_grad = Flase,从计算图中分离出来了,新tensor与源tensor之间不会流通梯度信息。

import torch

a = torch.tensor([1.,2.,3.],requires_grad=True,device='cuda')
b = a.detach()
print(a.data_ptr()) #43008393216
print(b.data_ptr()) #43008393216

print(a) # tensor([1., 2., 3.], device='cuda:0', requires_grad=True)
print(b) # tensor([1., 2., 3.], device='cuda:0')

c = 2*a 
d = 3*b
c.sum().backward()
print(a.grad) # tensor([2., 2., 2.], device='cuda:0')
d.sum().backward() # 报错
print(a.grad)

detach() 与 .data类似,同样是共享内存,并且不记录梯度信息。但 detach() 比 .data更加安全

4、参数初始化

在pytorch中,有两种情况不能使用 inplace operation:

  1. requires_grad = True 的叶子tensor
  2. 求梯度阶段需要用到的tensor

因此,如果直接对一个requires_grad = True 的叶子 tensor 做 in-place operation 会报错

import torch
a = torch.tensor([1.,2.,3.,4.])
w = torch.zeros(4)
w.requires_grad = True 
w.copy_(a) # RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

如果想要进行参数初始化,可以使用 .data 或者 detach()

import torch

a = torch.tensor([1.,2.,3.,4.])
w = torch.zeros(4)
w.requires_grad = True
print(w) # tensor([0., 0., 0., 0.], requires_grad=True)
w.data.copy_(a) #或者 w.detach().copy_(a)
print(w) # tensor([1., 2., 3., 4.], requires_grad=True)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值