Exporting the LaMa PyTorch model to ONNX
Conclusion
The export fails: PyTorch has no ONNX export support for the `torch.fft.rfftn` Fourier-transform operator used by LaMa, so the model cannot be exported.
Exporting with torch.onnx
```python
import os

import numpy as np
import torch
import yaml
from omegaconf import OmegaConf
from saicinpainting.training.trainers import load_checkpoint


def export_onnx(img: np.ndarray,
                mask: np.ndarray,
                ckpt_p: str,  # model directory containing config.yaml and models/
                config_p: str = "./lama/configs/prediction/default.yaml",
                mod=8,
                device="cuda"):
    predict_config = OmegaConf.load(config_p)
    predict_config.model.path = ckpt_p
    # device = torch.device(predict_config.device)
    device = torch.device(device)

    train_config_path = os.path.join(
        predict_config.model.path, 'config.yaml')
    with open(train_config_path, 'r') as f:
        train_config = OmegaConf.create(yaml.safe_load(f))
    train_config.training_model.predict_only = True
    train_config.visualizer.kind = 'noop'

    checkpoint_path = os.path.join(
        predict_config.model.path, 'models',
        predict_config.model.checkpoint
    )
    model = load_checkpoint(
        train_config, checkpoint_path, strict=False, map_location='cpu')
    model.eval().cpu()  # trace/export on CPU

    # NCHW layout: dim 2 is height, dim 3 is width
    dynamic_axes = {
        'img':    {0: 'batch_size', 2: 'height', 3: 'width'},
        'mask':   {0: 'batch_size', 2: 'height', 3: 'width'},
        'output': {0: 'batch_size', 2: 'height', 3: 'width'},
    }
    dummy_inputs = {
        "img": torch.randn((1, 3, 536, 800), dtype=torch.float),
        # high is exclusive, so high=2 is needed to get a binary 0/1 mask
        "mask": torch.randint(low=0, high=2, size=(1, 1, 536, 800)).float(),
    }
    # Pass the tensors positionally with an explicit trailing kwargs dict:
    # torch.onnx.export treats a trailing dict in `args` as keyword arguments.
    # This assumes model.forward(img, mask); if the training module's forward
    # expects a batch dict instead, wrap it in a small nn.Module first.
    torch.onnx.export(model,                                       # model being run
                      (dummy_inputs["img"], dummy_inputs["mask"], {}),
                      "lama_model.onnx",                           # where to save the model
                      export_params=True,                          # store trained weights in the file
                      opset_version=10,                            # ONNX opset version to target
                      do_constant_folding=True,                    # fold constants for optimization
                      input_names=['img', 'mask'],
                      output_names=['output'],                     # must match the dynamic_axes key
                      dynamic_axes=dynamic_axes)
```
Failure reason
Exporting the operator 'aten::fft_rfftn' to ONNX opset version 10 is not supported
The operator is called from the FFC modules:
# lama\bin\saicinpainting\training\modules\ffc.py
# lama\saicinpainting\training\modules\ffc.py
There is no real fix, since PyTorch has no official ONNX export support for the FFT operators. One reported workaround is to replace the operator with torch.atan(), but the resulting quality is said to be noticeably worse.
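The failure can be reproduced without LaMa at all: any module that calls torch.fft.rfftn hits the same error at opset 10. A minimal sketch (FFTModule is a made-up stand-in, not LaMa code):

```python
import io
import torch

class FFTModule(torch.nn.Module):
    """Toy module that only applies the operator LaMa's FFC blocks rely on."""
    def forward(self, x):
        return torch.fft.rfftn(x, dim=(-2, -1)).real

def try_export(opset: int) -> str:
    """Attempt an in-memory ONNX export; return 'ok' or the error class name."""
    try:
        torch.onnx.export(FFTModule(), torch.randn(1, 3, 8, 8),
                          io.BytesIO(), opset_version=opset)
        return "ok"
    except Exception as e:  # e.g. UnsupportedOperatorError at old opsets
        return type(e).__name__

print(try_export(10))
```

Running in eager mode works fine; only the symbolic conversion step fails, which is why the error surfaces deep inside torch.onnx.export rather than during tracing.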
Currently supported operators
https://pytorch.org/docs/stable/onnx_supported_aten_ops.html
Reference links:
https://blog.csdn.net/xz1308579340/article/details/124908825
https://github.com/advimman/lama/issues/84