深入理解pytorch中计算图的inplace操作

a=1
print(id(a))
a=2
print(id(a))

并不是在1的空间删除填上2,而是新开辟了空间。

a=[1]
print(id(a[0]))
a[0]=1
print(id(a[0]))

这个是Inplace操作。

embedding=nn.Parameter(torch.rand(2,3))
d=nn.Parameter(torch.rand(3,3))
user_embeddings=embedding.clone()
user_embedding_input = user_embeddings[0] 
a=user_embedding_input*3#option1
print(a)
a=torch.matmul(d,user_embedding_input)#option2
print(a)
user_embeddings[0]=a
loss=a.sum()
loss.backward()#是否报错?

报错。

这里涉及一个概念,你直接[0]这样索引,这种属于selectbackward。不会创建新的内存空间,类似的还有slicebackward(例如b[:2,:1]),其也不会创建新的内存空间。然后在后面又进行了赋值,这样,在计算d的梯度的时候显然会报错。

embedding=nn.Parameter(torch.rand(2,3))
d=nn.Parameter(torch.rand(3,3))
user_embeddings=embedding.clone()
user_embedding_input = user_embeddings[[0],:] 
# a=user_embedding_input*3#option1
# print(a)
a=torch.matmul(user_embedding_input,d)#option2
print(a)
user_embeddings[[0],:]=a
loss=a.sum()
loss.backward()#是否报错?

不报错,上面的索引是indexbackward,这个相当于创建了一个新的变量,然后index操作,梯度回传即可。虽然后面user_embeddings改了,但是那个属于中间节点,把user_embedding_input的梯度传过来即可,然后再传给前面的embedding,可以发现,user_embeddings改不改都没有关系。这并不会导致什么错误,而且反向传播之后会清空中间节点的梯度。(补充:indexbackward取出的时候会创建新变量,并和原来脱离关系,但是如果是要更新vv,则三种索引都会改变vv。vv[select]=1,vv[slice]=1,vv[index]=1这三者都会改变vv。这可能是pytorch出于方便考虑的。总之index和前两者只有在取出来的时候会不一样。)

embedding=nn.Parameter(torch.rand(2,3))
d=nn.Parameter(torch.rand(3,3))
user_embeddings=embedding.clone()
user_embedding_input = user_embeddings[[0],:] 
a=torch.matmul(user_embedding_input,d)
print(a)
user_embeddings[[0],:]=a
user_embedding_input=3#question line
loss=a.sum()
loss.backward()#是否报错?

不报错。

这里有人有疑问了,为什么user_embedding_input改了还是不报错,这是因为计算梯度有缓存,而且这个改也不是Inplace的,Pytorch已经缓存了那个原来的空间,所以不报错。

user_embedding_input[0]=2

如果你这么操作,那么就会报错了。

另外一个知识点,中间节点的赋值会连带上之前的计算图。

a=nn.Parameter(torch.tensor([[2.]]))#叶子节点。
b=a.clone()#中间节点。
print(b)
d=nn.Parameter(torch.tensor(3.))
print(d)
e=b[[0],:]*d
b[0]=e#赋值,会带上e的历史。而不仅仅是一个数据。
print(e)
loss1=e.sum()
e=b[[0],:]*d
b[0]=e
print(e)
loss1+=e.sum()
loss1.backward()
d.grad#14
a.grad#12

上面这样其实有点类似于RNN了,这个你能否计算对呢?

在这里插入图片描述
补充,上面画得有点不对,那个b节点也应该分裂成两个。我们假设b分成b1,b2,然后下面那两个b1,b2改名叫做bb1,bb2。

这里的难点在于,b1被Inplace了,讲道理b1是从a那里复制过来的,所以是cloneback。然后bb1是从b1那里index过来的。那b1的grad是怎么记录的呢?因为b1已经不存在了,缓存也失效了,因为直接被Inplace了。既然无法记录,如果继续反向传播到a节点呢?

补充

  1. 一个tensor不可导,对其部分进行赋值inplace一个可导的,整个tensor都会变得可导,也就是说pytorch里面计算梯度是以对象为单位的。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

音程

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值