报错信息:`torch.jit.frontend.UnsupportedNodeError: function definitions aren't supported`
class MyModule(torch.jit.ScriptModule):
    """Script module holding ten ``SubModule`` instances in ``self.mods``."""

    # NOTE(review): in the legacy ScriptModule API (PyTorch 1.0-era docs),
    # listing 'mods' in __constants__ is what lets a @torch.jit.script_method
    # iterate the ModuleList (`for module in self.mods`). Newer releases may
    # not require it -- confirm against the PyTorch version in use.
    __constants__ = ['mods']

    def __init__(self):
        super(MyModule, self).__init__()
        # Ten independent SubModule instances; the loop index is not needed.
        self.mods = torch.nn.ModuleList([SubModule() for _ in range(10)])
错误写法(在 `forward` 内部又定义了函数 `_branch`,TorchScript 编译时会报上面的错误):
# Counter-example kept to illustrate the error above -- do not copy as-is.
@torch.jit.script_method
def forward(self, x, targets=None):
    # TorchScript compiles this whole body, and the nested function
    # definition below is what raises
    # "UnsupportedNodeError: function definitions aren't supported".
    def _branch(_embedding, _in):
        # Pass `_in` through each layer of `_embedding` in order, also keeping
        # the intermediate output produced by the layer at index 4.
        for i, e in enumerate(_embedding):
            _in = e(_in)
            if i == 4:
                out_branch = _in
        # NOTE(review): `out_branch` is only bound when `_embedding` has at
        # least 5 layers; otherwise this line raises UnboundLocalError in
        # plain Python.
        return _in, out_branch

    # backbone
    x2, x1, x0 = self.backbone(x.cuda())
    # yolo branch 0
    out0, out0_branch = _branch(self.embedding0, x0)
    # NOTE(review): the snippet appears truncated here (no return value).
正确写法(`forward` 内部不定义任何函数,直接依次调用子模块):
@torch.jit.script_method
def forward(self, v):
    """Feed ``v`` through every module in ``self.mods``, in order.

    Each module's output becomes the next module's input, and the final
    output is returned. Nothing is defined inside this body, so TorchScript
    can compile it.
    """
    for module in self.mods:
        v = module(v)
    return v
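上面那个错误的原因:
代码缩进(对齐)有误。`def _branch` 写在了 `forward` 内部,成了嵌套函数,而 TorchScript 不支持在被编译的函数里再定义函数。应把 `def _branch` 移到 `forward` 外面,与 `def __init__(self)` 左对齐,作为类的方法;或者把其中的逻辑直接展开写进 `forward`。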
解决方法 2:
注释掉 `@torch.jit.script_method`(或 `@torch.jit.script`)装饰器,不再用 TorchScript 编译该函数,即可绕过这个报错;代价是这部分代码只能以普通 Python 方式运行。