基础环境:
ubuntu20.04
torch1.8
报错详情:
使用torch.jit.trace转换静态模型任务中,在模型完成转换进行save时出现报错:Could not export Python function call 'CheckpointFunction'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
报错分析:
由于torch.jit.trace是以静态追踪的方式来记录模型,其在模型单次向前传播时执行记录。而checkpoint.checkpoint的原理为:在模型进行向前传播的时候只存储一部分激活值,并且在反向传播需要的时候重新计算其余值,具有动态性,该方法的目的在于降低内存占用。torch.jit.trace无法运行动态性函数,因此会出现上述报错。
解决措施:
不使用checkpoint.checkpoint方法,将其注释,直接运行向前传播过程