PyTorch 中的 while 语句在导出 ONNX 时需要使用特殊的函数进行处理,这个函数叫 torch.ops.script_ops.while_loop。该函数接受三个参数:循环条件、循环体和循环初始值。
示例:
import torch
import torch.onnx
def my_loop(counter, threshold):
while counter < threshold:
counter += 1
return counter
counter = torch.tensor(0, dtype=torch.flo