Error:
Today I ran into this error:
ImportError: cannot import name 'zero_gradients' from 'torch.autograd.gradcheck' (/home/ubuntu/anaconda3/envs/OTLA/ges/torch/autograd/gradcheck.py)
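Solution:
Newer PyTorch versions no longer provide zero_gradients in torch.autograd.gradcheck, but advertorch still imports it from there. The fix is to define the function locally. In
~/anaconda3/envs/d2l/lib/python3.7/site-packages/advertorch/attacks/fast_adaptive_boundary.py
delete the line
from torch.autograd.gradcheck import zero_gradients
and add the following definition at module level in its place: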
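import collections.abc  # needed for the Iterable check below
import torch  # only needed if the file does not already import torch

def zero_gradients(x):
    # Reset x.grad in place; recurse into lists/tuples of tensors.
    if isinstance(x, torch.Tensor):
        if x.grad is not None:
            x.grad.detach_()
            x.grad.zero_()
    elif isinstance(x, collections.abc.Iterable):
        for elem in x:
            zero_gradients(elem)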
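If the same file has to keep working on older PyTorch installs too, a slightly more defensive variant is to try the original import first and only fall back to the local definition when it fails. This is a sketch of that idea, not part of advertorch itself; the quick check at the bottom (the tensor `w`) is a hypothetical example meant for a standalone script, not for pasting into fast_adaptive_boundary.py.

```python
import collections.abc

import torch

try:
    # Older PyTorch releases still ship this helper.
    from torch.autograd.gradcheck import zero_gradients
except ImportError:
    def zero_gradients(x):
        """Zero the .grad of a tensor, or of every tensor in a (nested) iterable."""
        if isinstance(x, torch.Tensor):
            if x.grad is not None:
                # Detach the gradient from any graph, then fill it with zeros in place.
                x.grad.detach_()
                x.grad.zero_()
        elif isinstance(x, collections.abc.Iterable):
            for elem in x:
                zero_gradients(elem)


# Quick check (hypothetical example): build a gradient, then clear it.
w = torch.ones(3, requires_grad=True)
(w * 2).sum().backward()
print(w.grad)        # tensor([2., 2., 2.])
zero_gradients([w])  # a single tensor works too: zero_gradients(w)
print(w.grad)        # tensor([0., 0., 0.])
```

Tensors whose .grad is still None are left untouched, which matches what the removed PyTorch helper did.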
Reference: https://blog.csdn.net/weixin_43633501/article/details/136007959