在 PyTorch 中,使用 while 循环时需要使用 torch.jit.trace_module() 或 torch.jit.script_method() 来手动跟踪模型并导出 ONNX。具体方法如下:
- 将模型转换为 torch.jit.ScriptModule
- 使用 torch.jit.trace_module() 跟踪模型并输入样本
- 使用 torch.onnx.export() 导出 ONNX 模型
例如:
import torch
import torch.onnx
class MyModule(torch.nn.Module):
def