遇到问题
最近在复现运行代码时候遇到以下错误:
ImportError: cannot import name ‘zero_gradients’ from ‘torch.autograd.gradcheck’ (C:\Users\zsh\anaconda3\envs\zshpytorch\lib\site-packages\torch\autograd\gradcheck.py)
解决方法
删掉以下代码
from torch.autograd.gradcheck import zero_gradients
修改替换为
def zero_gradients(x):
if isinstance(x, torch.Tensor):
if x.grad is not None:
x.grad.detach_()
x.grad.zero_()
elif isinstance(x, collections.abc.Iterable):
for elem in x:
zero_gradients(elem)