这错是因为在forward前面加了注解@torch.jit.script_method
因为自己看到官方给是使用就是这样:
import torch
class MyModule(torch.jit.ScriptModule):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
@torch.jit.script_method
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
my_script_module = MyModule()
因为自己一直保存不了torchScrip,看到这篇就怀疑是自己forward里面带有参数的缘故
然后改了就报错了,因为自己的代码里是进行了回调,因为我的代码封装了几层,导致自己掉坑里了