报错完整信息:RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method.
意思就是由于你当前的pytorch版本过高,而原代码的版本较低。如果pytorch版本高于1.3会出现该问题。当前版本要求forward过程是静态的,所以需要将原代码进行修改。
网上的一些做法是降低pytorch的版本,显然不是一个很好的解决方法,通过查询相关资料发现解决这个问题很简单,如下代码所示:
class Exp(Function):
@staticmethod
def forward(ctx, i):
result = i.exp()
ctx.save_for_backward(result)
return result
@staticmethod
def backward(ctx, grad_output):
result, = ctx.saved_tensors
return grad_output * result
即只需要在forward()和backward()前面加上@staticmethod即可运行通过。