在 PyTorch 中,导出 ONNX 模型时,如果在模型中有 if 语句,则需要使用 torch.jit.trace() 函数将该部分转化为固定的计算图。
例如:
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc = nn.Linear(10, 10)
def forward(self, x):
if x.sum() > 0: