pytorch使用backward()时出现'NoneType' object has no attribute 'zero_'的一种解决方法

笔者在学习使用pytorch的backward()时,发现了一个问题。原有代码是这样的

import torch
N,D_in,H,D_out = 10,1000,100,10
w1=torch.randn(D_in,H,requires_grad=True) # D_in * H
w2=torch.randn(H,D_out,requires_grad=True) # H * D_out 
ita=1e-6
x=torch.randn(N,D_in) # N * D_in 
y=torch.randn(N,D_out) # N * D_out

for t in range(500):

    #forward pass
    h=x.mm(w1) # N * H
    h_relu=h.clamp(min=0) # N * H
    y_pred=h_relu.mm(w2) # N * D_out

    #loss function
    loss=(y_pred-y).pow(2).sum()
    print(t,loss.item())
    #backward pass
    
    loss.backward()
       

    # update w1,w2

    w1=w1-ita*w1.grad
    w2=w2-ita*w2.grad
    w1.grad.zero_()
    w2.grad.zero_()
    

运行时出现错误:

'NoneType' object has no attribute 'zero_'

错误代码行为:w1.grad.zero_()
经分析,原因是:进行w1=w1-ita*w1.grad时,w1已由叶子节点变成了中间节点,而中间节点的grad会因为节约内存而被删除,所以w1的属性grad视为None

解决方法:在w1=w1-ita*w1.grad后添加代码w1.retain_grad()
对w2的处理同理。for内的代码块改为

for t in range(500):
#forward pass
h=x.mm(w1) # N * H
h_relu=h.clamp(min=0) # N * H
y_pred=h_relu.mm(w2) # N * D_out

#loss function
loss=(y_pred-y).pow(2).sum()
print(t,loss.item())
#backward pass

loss.backward()
   

# update w1,w2

w1=w1-ita*w1.grad
w1.retain_grad()
w2=w2-ita*w2.grad
w2.retain_grad()`

即可。此时程序能正常运行了。

而此时如果仍使用w1.grad.zero_()梯度清零,仍会报错:

'NoneType' object has no attribute 'zero_'

不知道为什么,求各位的解答…

  • 7
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值